blob: 1abe373180695619c3d8be4788050735f55ef38a [file] [log] [blame]
// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "base/task/thread_pool/job_task_source_interface.h"
#include <utility>
#include "base/functional/callback_helpers.h"
#include "base/memory/ptr_util.h"
#include "base/task/post_job.h"
#include "base/task/task_features.h"
#include "base/task/thread_pool/pooled_task_runner_delegate.h"
#include "base/task/thread_pool/test_utils.h"
#include "base/test/bind.h"
#include "base/test/gtest_util.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/test_timeouts.h"
#include "build/build_config.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using ::testing::_;
using ::testing::Return;
namespace base {
namespace internal {
class MockPooledTaskRunnerDelegate : public PooledTaskRunnerDelegate {
public:
MOCK_METHOD2(PostTaskWithSequence,
bool(Task task, scoped_refptr<Sequence> sequence));
MOCK_METHOD1(ShouldYield, bool(const TaskSource* task_source));
MOCK_METHOD1(EnqueueJobTaskSource,
bool(scoped_refptr<JobTaskSource> task_source));
MOCK_METHOD1(RemoveJobTaskSource,
void(scoped_refptr<JobTaskSource> task_source));
MOCK_CONST_METHOD1(IsRunningPoolWithTraits, bool(const TaskTraits& traits));
MOCK_METHOD2(UpdatePriority,
void(scoped_refptr<TaskSource> task_source,
TaskPriority priority));
MOCK_METHOD2(UpdateJobPriority,
void(scoped_refptr<TaskSource> task_source,
TaskPriority priority));
};
// Test fixture parameterized on a bool: the parameter decides whether
// kUseNewJobImplementation is enabled or disabled for the test.
class ThreadPoolJobTaskSourceTest : public testing::Test,
                                    public testing::WithParamInterface<bool> {
 protected:
  ThreadPoolJobTaskSourceTest() {
    // true -> feature enabled, false -> feature disabled.
    scoped_feature_list_.InitWithFeatureState(kUseNewJobImplementation,
                                              /*enabled=*/GetParam());
  }

  // Creates and starts a job which runs `callback` with
  // `initial_max_concurrency`.
  //
  // The MockJobTask is created with `initial_max_concurrency` tasks to run and
  // its JobTaskSource is bound to `pooled_task_runner_delegate_`. When
  // `initial_max_concurrency` is non-zero, a single EnqueueJobTaskSource() call
  // is expected on the delegate before NotifyConcurrencyIncrease() is invoked.
  // Returns the job task together with its task source.
  std::pair<scoped_refptr<test::MockJobTask>, scoped_refptr<JobTaskSource>>
  StartJob(size_t initial_max_concurrency,
           base::RepeatingCallback<void(JobDelegate*)> callback = DoNothing()) {
    auto job = base::MakeRefCounted<test::MockJobTask>(
        std::move(callback), /*num_tasks_to_run=*/initial_max_concurrency);
    auto source =
        job->GetJobTaskSource(FROM_HERE, {}, &pooled_task_runner_delegate_);

    const bool expects_enqueue = initial_max_concurrency > 0;
    if (expects_enqueue) {
      EXPECT_CALL(pooled_task_runner_delegate_, EnqueueJobTaskSource(_));
    }
    source->NotifyConcurrencyIncrease();
    return {job, source};
  }

  // Strict: any delegate call without a matching EXPECT_CALL fails the test.
  testing::StrictMock<MockPooledTaskRunnerDelegate>
      pooled_task_runner_delegate_;

 private:
  base::test::ScopedFeatureList scoped_feature_list_;
};
// Exercises the sequential path: two tasks of the same job are run one after
// the other through a single RegisteredTaskSource.
TEST_P(ThreadPoolJobTaskSourceTest, RunTasks) {
  auto [job, source] = StartJob(/*initial_max_concurrency=*/2);
  auto registered = RegisteredTaskSource::CreateForTesting(source);

  EXPECT_EQ(source->GetRemainingConcurrency(), 2U);

  // First task: the source is not saturated yet and has more work afterwards.
  {
    EXPECT_EQ(registered.WillRunTask(),
              TaskSource::RunStatus::kAllowedNotSaturated);
    EXPECT_EQ(source->GetWorkerCount(), 1U);

    auto first = registered.TakeTask();
    std::move(first.task).Run();
    EXPECT_TRUE(registered.DidProcessTask());
    EXPECT_EQ(source->GetWorkerCount(), 0U);
  }

  // Second task: this one saturates the source.
  {
    EXPECT_EQ(registered.WillRunTask(),
              TaskSource::RunStatus::kAllowedSaturated);
    EXPECT_EQ(source->GetWorkerCount(), 1U);

    // No additional worker may be admitted while saturated.
    EXPECT_EQ(RegisteredTaskSource::CreateForTesting(source).WillRunTask(),
              TaskSource::RunStatus::kDisallowed);
    EXPECT_EQ(source->GetRemainingConcurrency(), 0U);

    auto second = registered.TakeTask();
    EXPECT_EQ(RegisteredTaskSource::CreateForTesting(source).WillRunTask(),
              TaskSource::RunStatus::kDisallowed);

    std::move(second.task).Run();
    EXPECT_EQ(source->GetRemainingConcurrency(), 0U);
    EXPECT_TRUE(source->IsActive());

    // DidProcessTask() reports false: the task source has run out of tasks.
    EXPECT_FALSE(registered.DidProcessTask());
    EXPECT_EQ(source->GetWorkerCount(), 0U);
    EXPECT_FALSE(source->IsActive());
  }
}
// Verifies that a job task source doesn't allow any new RunStatus after Clear()
// is called.
//
// Four RegisteredTaskSources (a-d) are admitted against a max concurrency of 5.
// c and d are cleared; a (which already took its task) and b (which was
// admitted but had not yet taken one) must still be able to run.
TEST_P(ThreadPoolJobTaskSourceTest, Clear) {
auto [job_task, task_source] = StartJob(/* initial_max_concurrency=*/5);
EXPECT_EQ(5U, task_source->GetRemainingConcurrency());
// a: admitted and takes its task up front.
auto registered_task_source_a =
RegisteredTaskSource::CreateForTesting(task_source);
EXPECT_EQ(registered_task_source_a.WillRunTask(),
TaskSource::RunStatus::kAllowedNotSaturated);
auto task_a = registered_task_source_a.TakeTask();
// b, c, d: admitted, but none of them takes a task yet.
auto registered_task_source_b =
RegisteredTaskSource::CreateForTesting(task_source);
EXPECT_EQ(registered_task_source_b.WillRunTask(),
TaskSource::RunStatus::kAllowedNotSaturated);
auto registered_task_source_c =
RegisteredTaskSource::CreateForTesting(task_source);
EXPECT_EQ(registered_task_source_c.WillRunTask(),
TaskSource::RunStatus::kAllowedNotSaturated);
auto registered_task_source_d =
RegisteredTaskSource::CreateForTesting(task_source);
EXPECT_EQ(registered_task_source_d.WillRunTask(),
TaskSource::RunStatus::kAllowedNotSaturated);
EXPECT_FALSE(task_source->ShouldYield());
{
// Four of the five slots are in use, so exactly one remains.
EXPECT_EQ(1U, task_source->GetRemainingConcurrency());
// Clear() through c yields no task (the returned value tests false).
auto task = registered_task_source_c.Clear();
EXPECT_FALSE(task);
registered_task_source_c.DidProcessTask();
// After Clear() + DidProcessTask() no concurrency remains.
EXPECT_EQ(0U, task_source->GetRemainingConcurrency());
}
// The task source shouldn't allow any further tasks after Clear.
EXPECT_TRUE(task_source->ShouldYield());
EXPECT_EQ(RegisteredTaskSource::CreateForTesting(task_source).WillRunTask(),
TaskSource::RunStatus::kDisallowed);
// Another outstanding RunStatus can still call Clear.
{
auto task = registered_task_source_d.Clear();
EXPECT_FALSE(task);
registered_task_source_d.DidProcessTask();
EXPECT_EQ(0U, task_source->GetRemainingConcurrency());
}
// A task that was already acquired can still run.
std::move(task_a.task).Run();
registered_task_source_a.DidProcessTask();
// A valid outstanding RunStatus can also take and run a task.
{
auto task = registered_task_source_b.TakeTask();
std::move(task.task).Run();
registered_task_source_b.DidProcessTask();
}
}
// Checks that once Cancel() has been called, WillRunTask() no longer hands out
// an "allowed" RunStatus, while work admitted before the cancellation can
// still be taken and run.
TEST_P(ThreadPoolJobTaskSourceTest, Cancel) {
  auto [job, source] = StartJob(/*initial_max_concurrency=*/3);

  // Admit two workers before cancelling; only the first takes its task now.
  auto first_worker = RegisteredTaskSource::CreateForTesting(source);
  EXPECT_EQ(first_worker.WillRunTask(),
            TaskSource::RunStatus::kAllowedNotSaturated);
  auto first_task = first_worker.TakeTask();

  auto second_worker = RegisteredTaskSource::CreateForTesting(source);
  EXPECT_EQ(second_worker.WillRunTask(),
            TaskSource::RunStatus::kAllowedNotSaturated);

  // ShouldYield() flips from false to true across Cancel().
  EXPECT_FALSE(source->ShouldYield());
  source->Cancel();
  EXPECT_TRUE(source->ShouldYield());

  // New admissions are refused after Cancel().
  EXPECT_EQ(RegisteredTaskSource::CreateForTesting(source).WillRunTask(),
            TaskSource::RunStatus::kDisallowed);

  // The task acquired before Cancel() still runs.
  std::move(first_task.task).Run();
  first_worker.DidProcessTask();

  // A RegisteredTaskSource that was already admitted can also take and run a
  // task.
  {
    auto second_task = second_worker.TakeTask();
    std::move(second_task.task).Run();
    second_worker.DidProcessTask();
  }
}
// Verifies that multiple tasks can run in parallel up to |max_concurrency|.
//
// Two workers (a, b) are admitted concurrently against a max concurrency of 2;
// the worker count reported by the task source and by its sort key is checked
// at each step. A third worker (c) is admitted after the concurrency is raised
// again.
TEST_P(ThreadPoolJobTaskSourceTest, RunTasksInParallel) {
auto [job_task, task_source] = StartJob(/* initial_max_concurrency=*/2);
// a: first worker, source not yet saturated.
auto registered_task_source_a =
RegisteredTaskSource::CreateForTesting(task_source);
EXPECT_EQ(registered_task_source_a.WillRunTask(),
TaskSource::RunStatus::kAllowedNotSaturated);
EXPECT_EQ(1U, task_source->GetWorkerCount());
EXPECT_EQ(1U, task_source->GetSortKey().worker_count());
auto task_a = registered_task_source_a.TakeTask();
// b: second worker, which saturates the source.
auto registered_task_source_b =
RegisteredTaskSource::CreateForTesting(task_source);
EXPECT_EQ(registered_task_source_b.WillRunTask(),
TaskSource::RunStatus::kAllowedSaturated);
EXPECT_EQ(2U, task_source->GetWorkerCount());
EXPECT_EQ(2U, task_source->GetSortKey().worker_count());
auto task_b = registered_task_source_b.TakeTask();
// WillRunTask() should return kDisallowed once the max concurrency is
// reached.
EXPECT_EQ(RegisteredTaskSource::CreateForTesting(task_source).WillRunTask(),
TaskSource::RunStatus::kDisallowed);
// a finishes first; afterwards only b is still counted as a worker.
std::move(task_a.task).Run();
EXPECT_FALSE(registered_task_source_a.DidProcessTask());
EXPECT_EQ(1U, task_source->GetSortKey().worker_count());
// Increasing max concurrency above the number of workers should cause the
// task source to re-enqueue.
job_task->SetNumTasksToRun(2);
EXPECT_CALL(pooled_task_runner_delegate_, EnqueueJobTaskSource(_));
task_source->NotifyConcurrencyIncrease();
// b finishes; DidProcessTask() now reports true and no worker remains.
std::move(task_b.task).Run();
EXPECT_TRUE(registered_task_source_b.DidProcessTask());
EXPECT_EQ(0U, task_source->GetSortKey().worker_count());
EXPECT_EQ(0U, task_source->GetWorkerCount());
// c: admitted as saturated, runs one more task, after which
// DidProcessTask() reports false.
auto registered_task_source_c =
RegisteredTaskSource::CreateForTesting(task_source);
EXPECT_EQ(registered_task_source_c.WillRunTask(),
TaskSource::RunStatus::kAllowedSaturated);
auto task_c = registered_task_source_c.TakeTask();
std::move(task_c.task).Run();
EXPECT_FALSE(registered_task_source_c.DidProcessTask());
}
// Exercises the join path on its own: the joining thread runs the join task
// until the job completes.
TEST_P(ThreadPoolJobTaskSourceTest, RunJoinTask) {
  auto job = base::MakeRefCounted<test::MockJobTask>(DoNothing(),
                                                     /*num_tasks_to_run=*/2);
  scoped_refptr<JobTaskSource> source =
      job->GetJobTaskSource(FROM_HERE, {}, &pooled_task_runner_delegate_);

  EXPECT_TRUE(source->WillJoin());

  // RunJoinTask() is deliberately invoked twice to make sure it calls
  // |worker_task| again. In production this can happen when the joining thread
  // returns spuriously and needs to run again.
  EXPECT_TRUE(source->RunJoinTask());
  EXPECT_FALSE(source->RunJoinTask());
}
// Checks that the |worker_count| passed to max_concurrency_callback does not
// include the (inactive) returning thread that is invoking the callback.
TEST_P(ThreadPoolJobTaskSourceTest, RunTaskWorkerCount) {
  size_t pending_work = 1;
  scoped_refptr<JobTaskSource> source = internal::CreateJobTaskSource(
      FROM_HERE, TaskTraits(),
      // Worker task: consumes one unit of work per run.
      BindLambdaForTesting([&](JobDelegate* /*delegate*/) { --pending_work; }),
      // Max concurrency: outstanding work plus the workers reported as active.
      BindLambdaForTesting([&](size_t active_workers) -> size_t {
        return pending_work + active_workers;
      }),
      &pooled_task_runner_delegate_);

  EXPECT_CALL(pooled_task_runner_delegate_, EnqueueJobTaskSource(_));
  source->NotifyConcurrencyIncrease();

  auto registered = RegisteredTaskSource::CreateForTesting(source);
  EXPECT_EQ(registered.WillRunTask(),
            TaskSource::RunStatus::kAllowedSaturated);

  auto worker_task = registered.TakeTask();
  std::move(worker_task.task).Run();

  // After the worker task has run, |worker_count| should fall to 0 and the job
  // should be finished.
  EXPECT_FALSE(registered.DidProcessTask());
  EXPECT_EQ(pending_work, 0U);
}
// Checks that the |worker_count| passed to max_concurrency_callback does not
// include the (inactive) joining thread that is invoking the callback.
TEST_P(ThreadPoolJobTaskSourceTest, RunJoinTaskWorkerCount) {
  size_t pending_work = 1;
  scoped_refptr<JobTaskSource> source = internal::CreateJobTaskSource(
      FROM_HERE, TaskTraits(),
      // Worker task: consumes one unit of work per run.
      BindLambdaForTesting([&](JobDelegate* /*delegate*/) { --pending_work; }),
      // Max concurrency: outstanding work plus the workers reported as active.
      BindLambdaForTesting([&](size_t active_workers) -> size_t {
        return pending_work + active_workers;
      }),
      &pooled_task_runner_delegate_);

  EXPECT_TRUE(source->WillJoin());

  // After the worker task has run, |worker_count| should fall to 0 and the job
  // should be finished.
  EXPECT_FALSE(source->RunJoinTask());
  EXPECT_EQ(pending_work, 0U);
}
// Checks that WillJoin() refuses to let a joining thread contribute when
// Cancel() was called beforehand.
TEST_P(ThreadPoolJobTaskSourceTest, CancelJoinTask) {
  auto job = base::MakeRefCounted<test::MockJobTask>(DoNothing(),
                                                     /*num_tasks_to_run=*/2);
  scoped_refptr<JobTaskSource> source =
      job->GetJobTaskSource(FROM_HERE, {}, &pooled_task_runner_delegate_);

  // Cancel first, then attempt to join.
  source->Cancel();
  EXPECT_FALSE(source->WillJoin());
}
// Checks that RunJoinTask() refuses to let a joining thread contribute when
// Cancel() is called after WillJoin() succeeded.
TEST_P(ThreadPoolJobTaskSourceTest, JoinCancelTask) {
  auto job = base::MakeRefCounted<test::MockJobTask>(DoNothing(),
                                                     /*num_tasks_to_run=*/2);
  scoped_refptr<JobTaskSource> source =
      job->GetJobTaskSource(FROM_HERE, {}, &pooled_task_runner_delegate_);

  // Join first, then cancel before the join task gets to run.
  EXPECT_TRUE(source->WillJoin());
  source->Cancel();
  EXPECT_FALSE(source->RunJoinTask());
}
// Checks that a joining thread and a worker can both be admitted at the same
// time, up to |max_concurrency|.
TEST_P(ThreadPoolJobTaskSourceTest, RunJoinTaskInParallel) {
  auto [job, source] = StartJob(/*initial_max_concurrency=*/2);

  // A regular worker is admitted and takes its task first.
  auto registered = RegisteredTaskSource::CreateForTesting(source);
  EXPECT_EQ(registered.WillRunTask(),
            TaskSource::RunStatus::kAllowedNotSaturated);
  auto pool_task = registered.TakeTask();

  // The joining thread is admitted alongside the worker.
  EXPECT_TRUE(source->WillJoin());
  EXPECT_TRUE(source->IsActive());

  // The worker runs its task, then the joining thread runs the join task;
  // afterwards the source is no longer active.
  std::move(pool_task.task).Run();
  EXPECT_FALSE(registered.DidProcessTask());

  EXPECT_FALSE(source->RunJoinTask());
  EXPECT_FALSE(source->IsActive());
}
// Checks that NotifyConcurrencyIncrease() reaches the delegate and makes it
// possible to run additional tasks.
TEST_P(ThreadPoolJobTaskSourceTest, NotifyConcurrencyIncrease) {
  auto [job, source] = StartJob(/*initial_max_concurrency=*/1);

  // With a max concurrency of 1, the first worker saturates the source.
  auto first_worker = RegisteredTaskSource::CreateForTesting(source);
  EXPECT_EQ(first_worker.WillRunTask(),
            TaskSource::RunStatus::kAllowedSaturated);
  auto first_task = first_worker.TakeTask();
  EXPECT_EQ(RegisteredTaskSource::CreateForTesting(source).WillRunTask(),
            TaskSource::RunStatus::kDisallowed);

  // Raise the amount of work and notify; the delegate must be asked to enqueue
  // the task source exactly once.
  job->SetNumTasksToRun(2);
  EXPECT_CALL(pooled_task_runner_delegate_, EnqueueJobTaskSource(_)).Times(1);
  source->NotifyConcurrencyIncrease();

  // Max concurrency is now 2, so WillRunTask() admits a second worker.
  auto second_worker = RegisteredTaskSource::CreateForTesting(source);
  EXPECT_EQ(second_worker.WillRunTask(),
            TaskSource::RunStatus::kAllowedSaturated);
  auto second_task = second_worker.TakeTask();
  EXPECT_EQ(RegisteredTaskSource::CreateForTesting(source).WillRunTask(),
            TaskSource::RunStatus::kDisallowed);

  // Both workers run their task to completion.
  std::move(first_task.task).Run();
  EXPECT_FALSE(first_worker.DidProcessTask());

  std::move(second_task.task).Run();
  EXPECT_FALSE(second_worker.DidProcessTask());
}
// Checks that JobDelegate::ShouldYield() is forwarded to the pooled task
// runner delegate.
TEST_P(ThreadPoolJobTaskSourceTest, ShouldYield) {
  auto [job, source] = StartJob(
      /*initial_max_concurrency=*/1,
      BindLambdaForTesting([](JobDelegate* job_delegate) {
        // The mock is configured further down to answer false on the first
        // call and true on the second.
        EXPECT_FALSE(job_delegate->ShouldYield());
        EXPECT_TRUE(job_delegate->ShouldYield());
      }));

  auto registered = RegisteredTaskSource::CreateForTesting(source);
  ASSERT_EQ(registered.WillRunTask(),
            TaskSource::RunStatus::kAllowedSaturated);
  auto worker_task = registered.TakeTask();

  // Exactly two delegate calls are expected while the worker task runs.
  EXPECT_CALL(pooled_task_runner_delegate_, ShouldYield(_))
      .Times(2)
      .WillOnce(Return(false))
      .WillOnce(Return(true));

  std::move(worker_task.task).Run();
  EXPECT_FALSE(registered.DidProcessTask());
}
// Checks that a worker task may return without max concurrency having changed,
// provided ShouldYield() returned true.
TEST_P(ThreadPoolJobTaskSourceTest, MaxConcurrencyStagnateIfShouldYield) {
  scoped_refptr<JobTaskSource> source = internal::CreateJobTaskSource(
      FROM_HERE, TaskTraits(),
      BindRepeating([](JobDelegate* job_delegate) {
        // The mock delegate is configured below to answer true once.
        ASSERT_TRUE(job_delegate->ShouldYield());
      }),
      BindRepeating([](size_t /*worker_count*/) -> size_t {
        // Max concurrency never moves: it is always 1.
        return 1;
      }),
      &pooled_task_runner_delegate_);

  EXPECT_CALL(pooled_task_runner_delegate_, EnqueueJobTaskSource(_));
  source->NotifyConcurrencyIncrease();

  EXPECT_CALL(pooled_task_runner_delegate_, ShouldYield(_))
      .WillOnce(Return(true));

  auto registered = RegisteredTaskSource::CreateForTesting(source);
  ASSERT_EQ(registered.WillRunTask(),
            TaskSource::RunStatus::kAllowedSaturated);
  auto worker_task = registered.TakeTask();

  // Even though max concurrency stayed at 1, running the task must not fail
  // because ShouldYield() returned true.
  std::move(worker_task.task).Run();
  registered.DidProcessTask();
}
// Checks that TakeTask() hits a DCHECK when invoked on a RegisteredTaskSource
// whose WillRunTask() returned kDisallowed.
TEST_P(ThreadPoolJobTaskSourceTest, InvalidTakeTask) {
  auto [job, source] = StartJob(/*initial_max_concurrency=*/1);

  // The only available slot goes to the first worker.
  auto admitted = RegisteredTaskSource::CreateForTesting(source);
  EXPECT_EQ(admitted.WillRunTask(), TaskSource::RunStatus::kAllowedSaturated);

  // The second worker is therefore turned away.
  auto rejected = RegisteredTaskSource::CreateForTesting(source);
  EXPECT_EQ(rejected.WillRunTask(), TaskSource::RunStatus::kDisallowed);

  // TakeTask() must not be called with an invalid RunStatus.
  EXPECT_DCHECK_DEATH({ auto task = rejected.TakeTask(); });

  // The admitted worker takes its task and completes normally.
  auto task = admitted.TakeTask();
  admitted.DidProcessTask();
}
// Checks that DidProcessTask() hits a DCHECK when invoked on a
// RegisteredTaskSource for which WillRunTask() was never called.
TEST_P(ThreadPoolJobTaskSourceTest, InvalidDidProcessTask) {
  auto job = base::MakeRefCounted<test::MockJobTask>(DoNothing(),
                                                     /*num_tasks_to_run=*/1);
  scoped_refptr<JobTaskSource> source =
      job->GetJobTaskSource(FROM_HERE, {}, &pooled_task_runner_delegate_);
  auto registered = RegisteredTaskSource::CreateForTesting(source);

  // DidProcessTask() must not be called before WillRunTask().
  EXPECT_DCHECK_DEATH(registered.DidProcessTask());
}
// Checks the ids returned by AcquireTaskId(): consecutive values starting at 0,
// with ids given back through ReleaseTaskId() being returned again before a
// new one is issued.
TEST_P(ThreadPoolJobTaskSourceTest, AcquireTaskId) {
  auto job = base::MakeRefCounted<test::MockJobTask>(DoNothing(),
                                                     /*num_tasks_to_run=*/4);
  scoped_refptr<JobTaskSource> source =
      job->GetJobTaskSource(FROM_HERE, {}, &pooled_task_runner_delegate_);

  // The first five acquisitions yield 0, 1, 2, 3, 4 in that order.
  for (size_t expected_id = 0; expected_id < 5; ++expected_id) {
    EXPECT_EQ(source->AcquireTaskId(), expected_id);
  }

  // Give back two ids from the middle of the range.
  source->ReleaseTaskId(1);
  source->ReleaseTaskId(3);

  // The released ids come back first (1, then 3), followed by a new id.
  EXPECT_EQ(source->AcquireTaskId(), 1U);
  EXPECT_EQ(source->AcquireTaskId(), 3U);
  EXPECT_EQ(source->AcquireTaskId(), 5U);
}
// Checks that the task id is released once worker_task returns, so that a
// subsequent run observes the same id.
TEST_P(ThreadPoolJobTaskSourceTest, GetTaskId) {
  auto [job, source] = StartJob(
      /*initial_max_concurrency=*/2,
      BindRepeating([](JobDelegate* job_delegate) {
        // Both runs must see id 0, i.e. the id is reused on the second run.
        EXPECT_EQ(0U, job_delegate->GetTaskId());
      }));
  auto registered = RegisteredTaskSource::CreateForTesting(source);

  // First run of the worker_task.
  ASSERT_EQ(registered.WillRunTask(),
            TaskSource::RunStatus::kAllowedNotSaturated);
  auto first_run = registered.TakeTask();
  std::move(first_run.task).Run();
  registered.DidProcessTask();

  // Second run of the worker_task.
  ASSERT_EQ(registered.WillRunTask(),
            TaskSource::RunStatus::kAllowedSaturated);
  auto second_run = registered.TakeTask();
  std::move(second_run.task).Run();
  registered.DidProcessTask();
}
// Runs every test in the suite twice, once per value of the bool parameter.
// The name generator labels the instances "NewJob" (param == true) and
// "OldJob" (param == false).
INSTANTIATE_TEST_SUITE_P(,
                         ThreadPoolJobTaskSourceTest,
                         testing::Bool(),
                         [](const testing::TestParamInfo<bool>& info) {
                           return info.param ? "NewJob" : "OldJob";
                         });
} // namespace internal
} // namespace base