mach-detours/tests/test_threads.cpp

155 lines
4.9 KiB
C++

// Copyright (c) Lysann Tranvouez. All rights reserved.
#include <pthread.h>
#include <unistd.h>

#include <algorithm>
#include <array>
#include <atomic>
#include <chrono>
#include <cstddef>
#include <optional>

#include <catch2/catch_test_macros.hpp>

#include <mach_detours.h>
// Helper invoked by localFunction(). The three arguments are accepted but never read;
// the result is always the string literal "localFunction".
static const char* callee(int a, int b, int c)
{
    static_cast<void>(a);
    static_cast<void>(b);
    static_cast<void>(c);

    const char* const result = "localFunction";
    return result;
}
// Scratch state for localFunction(): g_b is post-incremented on every call and g_c then
// receives a copy of it. The test case samples g_b to tell whether worker threads are running.
// NOTE(review): these are plain ints that the worker threads write without synchronization.
static int g_b{6};
static int g_c{3};
// The function the tests detour (via realLocalFunction); every worker thread calls it in a loop.
// Side effects: post-increments g_b, then copies the updated g_b into g_c.
// Returns whatever callee() returns, i.e. the string literal "localFunction".
// The target attribute requests pac-ret + BTI branch protection for this one function.
__attribute__((target("branch-protection=pac-ret+bti")))
static const char* localFunction()
{
// This is just some "complicated" code to ensure localFunction gets the branch protection instructions (needs a callee)
int a = 53;
// b receives the value of g_b from before the increment.
int b = g_b++;
auto ret = callee(a, b, 17);
g_c = g_b;
return ret;
}
// Replacement handed to detour_attach_and_commit() for localFunction(). It has no side
// effects (in particular it leaves g_b alone) and returns the literal "localFunctionDetour".
static const char* localFunctionDetour()
{
    const char* const detourName = "localFunctionDetour";
    return detourName;
}
// Function-pointer slot whose address is passed to detour_attach_and_commit() and
// detour_detach_and_commit(); it initially refers to the undetoured localFunction().
static const char* (*realLocalFunction)() = &localFunction;
// Per-worker state shared between the test's main thread and one worker thread.
// Both members are atomics because each is accessed from two threads concurrently.
struct threadData
{
    // Incremented by the worker after each pass through its loop body;
    // reset to zero and polled by the main thread.
    std::atomic<int> loopCounter{0};

    // Set to true by the main thread to make the worker leave its loop.
    // This must be atomic: with a plain bool the cross-thread write/read is a data race,
    // and the compiler may legally hoist the load out of the worker's loop.
    std::atomic<bool> cancel{false};
};
static void* threadFunc(void* pUserArg)
{
auto& data = *static_cast<threadData*>(pUserArg);
while (!data.cancel) {
localFunction();
localFunction();
localFunction();
localFunction();
localFunction();
++data.loopCounter;
}
return nullptr;
}
template<size_t numThreads>
bool resetLoopCountersAndWaitForLoop(std::array<threadData, numThreads>& threadDatas, std::optional<std::chrono::seconds> timeout = std::nullopt)
{
for (auto i = 0; i < numThreads; i++) {
threadDatas[i].loopCounter = 0;
}
const auto endTime = timeout ? std::make_optional(std::chrono::steady_clock::now() + *timeout) : std::nullopt;
while (true) {
// we check for >= 2 to be sure we did a full loop - we don't want to execute just the line incrementing the counter!
const bool anyLooped = std::ranges::any_of(threadDatas, [](const threadData& data){ return data.loopCounter >= 2; });
if (anyLooped) {
return true;
}
if (endTime && std::chrono::steady_clock::now() > *endTime) {
return false;
}
usleep( 100 );
}
}
// Spawns 150 pthreads that keep calling localFunction() (see threadFunc), then uses changes
// to g_b as evidence of whether those threads are currently making progress around detour
// transactions. The tail of the body cancels and joins every thread.
// Catch2 re-enters the test case once per leaf SECTION, so setup and teardown below run
// for each of them.
TEST_CASE( "Handling other threads", "[attach][local][threads]" )
{
using namespace std::chrono_literals;
constexpr auto numThreads = 150;
pthread_t threads[numThreads];
std::array<threadData, numThreads> threadDatas{};
// Each worker gets a pointer to its own threadData slot.
// NOTE(review): the return value of pthread_create is not checked.
for (auto i = 0; i < numThreads; i++) {
pthread_create(&threads[i], nullptr, threadFunc, (void*)&threadDatas[i]);
}
SECTION( "running threads modify a value" )
{
// Baseline: with no transaction open, g_b must change once some thread completes a loop.
const int saved_b = g_b;
resetLoopCountersAndWaitForLoop(threadDatas);
CHECK( saved_b != g_b );
}
SECTION( "registered threads are suspended during transaction and resumed afterwards" )
{
// The three nested sections are alternative ways of opening a transaction with the
// worker threads managed; the assertions after them are shared by all three.
SECTION( "manual thread managing" )
{
CHECK( detour_transaction_begin() == err_none );
for (auto&& thread : threads) {
CHECK( detour_manage_thread( pthread_mach_thread_np(thread) ) == err_none );
}
}
SECTION( "manage_all_threads" )
{
CHECK( detour_transaction_begin() == err_none );
CHECK( detour_manage_all_threads() == err_none );
}
SECTION( "transaction_begin_managed" )
{
CHECK( detour_transaction_begin_managed() == err_none );
}
// While the transaction is open, g_b is expected to stay unchanged. The 1s timeout bounds
// the wait, since no thread is expected to complete a loop here.
int saved_b = g_b;
resetLoopCountersAndWaitForLoop(threadDatas, 1s);
CHECK( saved_b == g_b );
CHECK( detour_transaction_commit() == err_none );
// After the commit the threads must make progress again, so g_b has to change.
saved_b = g_b;
resetLoopCountersAndWaitForLoop(threadDatas);
CHECK( saved_b != g_b );
}
SECTION( "aborting a transaction resumes too" )
{
CHECK( detour_transaction_begin_managed() == err_none );
CHECK( detour_transaction_abort() == err_none );
// Same progress check as after a commit: g_b must change once a thread loops.
const int saved_b = g_b;
resetLoopCountersAndWaitForLoop(threadDatas);
CHECK( saved_b != g_b );
}
SECTION( "thread's PC gets redirected on commit, if it was in an overridden area" )
{
// Note that this test case doesn't guarantee that any thread is in the section of code that gets moved into the
// trampoline. So if buggy it could be flaky. You can try using more threads to get a more reliable repro.
CHECK( detour_transaction_begin_managed() == err_none );
CHECK( detour_attach_and_commit(reinterpret_cast<detour_func_t*>(&realLocalFunction), reinterpret_cast<detour_func_t>(localFunctionDetour)) == err_none );
// With the detour attached the threads keep looping, but g_b must no longer change:
// localFunctionDetour() does not touch it.
const int saved_b = g_b;
resetLoopCountersAndWaitForLoop(threadDatas);
CHECK( saved_b == g_b );
// clean up
CHECK( detour_transaction_begin_managed() == err_none );
CHECK( detour_detach_and_commit(reinterpret_cast<detour_func_t*>(&realLocalFunction), reinterpret_cast<detour_func_t>(localFunctionDetour)) == err_none );
}
// Teardown: ask every worker to stop first, then join them all.
for (auto i = 0; i < numThreads; i++) {
threadDatas[i].cancel = true;
}
for (auto i = 0; i < numThreads; i++) {
pthread_join(threads[i], nullptr);
}
}