blob: df4d6c383abe6ed73d1b2f70100a06375f460aa6 [file] [log] [blame]
//===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Module.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Testing/Support/Error.h"
#include "llvm/Transforms/Coroutines/CoroSplit.h"
#include "gtest/gtest.h"
using namespace llvm;
namespace {
struct ExtraRematTest : public testing::Test {
LLVMContext Ctx;
ModulePassManager MPM;
PassBuilder PB;
LoopAnalysisManager LAM;
FunctionAnalysisManager FAM;
CGSCCAnalysisManager CGAM;
ModuleAnalysisManager MAM;
LLVMContext Context;
std::unique_ptr<Module> M;
ExtraRematTest() {
PB.registerModuleAnalyses(MAM);
PB.registerCGSCCAnalyses(CGAM);
PB.registerFunctionAnalyses(FAM);
PB.registerLoopAnalyses(LAM);
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
}
BasicBlock *getBasicBlockByName(Function *F, StringRef Name) const {
for (BasicBlock &BB : *F) {
if (BB.getName() == Name)
return &BB;
}
return nullptr;
}
CallInst *getCallByName(BasicBlock *BB, StringRef Name) const {
for (Instruction &I : *BB) {
if (CallInst *CI = dyn_cast<CallInst>(&I))
if (CI->getCalledFunction()->getName() == Name)
return CI;
}
return nullptr;
}
void ParseAssembly(const StringRef IR) {
SMDiagnostic Error;
M = parseAssemblyString(IR, Error, Context);
std::string errMsg;
raw_string_ostream os(errMsg);
Error.print("", os);
// A failure here means that the test itself is buggy.
if (!M)
report_fatal_error(os.str().c_str());
}
};
StringRef Text = R"(
define ptr @f(i32 %n) presplitcoroutine {
entry:
%id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
%size = call i32 @llvm.coro.size.i32()
%alloc = call ptr @malloc(i32 %size)
%hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc)
%inc1 = add i32 %n, 1
%val2 = call i32 @should.remat(i32 %inc1)
%sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
switch i8 %sp1, label %suspend [i8 0, label %resume1
i8 1, label %cleanup]
resume1:
%inc2 = add i32 %val2, 1
%sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
switch i8 %sp1, label %suspend [i8 0, label %resume2
i8 1, label %cleanup]
resume2:
call void @print(i32 %val2)
call void @print(i32 %inc2)
br label %cleanup
cleanup:
%mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
call void @free(ptr %mem)
br label %suspend
suspend:
call i1 @llvm.coro.end(ptr %hdl, i1 0)
ret ptr %hdl
}
declare ptr @llvm.coro.free(token, ptr)
declare i32 @llvm.coro.size.i32()
declare i8 @llvm.coro.suspend(token, i1)
declare void @llvm.coro.resume(ptr)
declare void @llvm.coro.destroy(ptr)
declare token @llvm.coro.id(i32, ptr, ptr, ptr)
declare i1 @llvm.coro.alloc(token)
declare ptr @llvm.coro.begin(token, ptr)
declare i1 @llvm.coro.end(ptr, i1)
declare i32 @should.remat(i32)
declare noalias ptr @malloc(i32)
declare void @print(i32)
declare void @free(ptr)
)";
// Materializable callback with extra rematerialization
bool ExtraMaterializable(Instruction &I) {
if (isa<CastInst>(&I) || isa<GetElementPtrInst>(&I) ||
isa<BinaryOperator>(&I) || isa<CmpInst>(&I) || isa<SelectInst>(&I))
return true;
if (auto *CI = dyn_cast<CallInst>(&I)) {
auto *CalledFunc = CI->getCalledFunction();
if (CalledFunc && CalledFunc->getName().startswith("should.remat"))
return true;
}
return false;
}
TEST_F(ExtraRematTest, TestCoroRematDefault) {
ParseAssembly(Text);
ASSERT_TRUE(M);
CGSCCPassManager CGPM;
CGPM.addPass(CoroSplitPass());
MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
MPM.run(*M, MAM);
// Verify that extra rematerializable instruction has been rematerialized
Function *F = M->getFunction("f.resume");
ASSERT_TRUE(F) << "could not find split function f.resume";
BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
ASSERT_TRUE(Resume1)
<< "could not find expected BB resume1 in split function";
// With default materialization the intrinsic should not have been
// rematerialized
CallInst *CI = getCallByName(Resume1, "should.remat");
ASSERT_FALSE(CI);
}
TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
ParseAssembly(Text);
ASSERT_TRUE(M);
CGSCCPassManager CGPM;
CGPM.addPass(
CoroSplitPass(std::function<bool(Instruction &)>(ExtraMaterializable)));
MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
MPM.run(*M, MAM);
// Verify that extra rematerializable instruction has been rematerialized
Function *F = M->getFunction("f.resume");
ASSERT_TRUE(F) << "could not find split function f.resume";
BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
ASSERT_TRUE(Resume1)
<< "could not find expected BB resume1 in split function";
// With callback the extra rematerialization of the function should have
// happened
CallInst *CI = getCallByName(Resume1, "should.remat");
ASSERT_TRUE(CI);
}
} // namespace