// Copyright (c) Lysann Tranvouez. All rights reserved. #include #include #include #include #include static const char* callee(int a, int b, int c) { return "localFunction"; } static int g_b = 6; static int g_c = 3; __attribute__((target("branch-protection=pac-ret+bti"))) static const char* localFunction() { // This is just some "complicated" code to ensure localFunction gets the branch protection instructions (needs a callee) int a = 53; int b = g_b++; auto ret = callee(a, b, 17); g_c = g_b; return ret; } static const char* localFunctionDetour() { return "localFunctionDetour"; } static const char* (*realLocalFunction)() = localFunction; struct threadData { std::atomic loopCounter; bool cancel = false; }; static void* threadFunc(void* pUserArg) { auto& data = *static_cast(pUserArg); while (!data.cancel) { localFunction(); localFunction(); localFunction(); localFunction(); localFunction(); ++data.loopCounter; } return nullptr; } template bool resetLoopCountersAndWaitForLoop(std::array& threadDatas, std::optional timeout = std::nullopt) { for (auto i = 0; i < numThreads; i++) { threadDatas[i].loopCounter = 0; } const auto endTime = timeout ? std::make_optional(std::chrono::steady_clock::now() + *timeout) : std::nullopt; while (true) { // we check for >= 2 to be sure we did a full loop - we don't want to execute just the line incrementing the counter! const bool anyLooped = std::ranges::any_of(threadDatas, [](const threadData& data){ return data.loopCounter >= 2; }); if (anyLooped) { return true; } if (endTime && std::chrono::steady_clock::now() > *endTime) { return false; } usleep( 100 ); } } TEST_CASE( "Handling other threads", "[attach][local][threads]" ) { using namespace std::chrono_literals; constexpr auto numThreads = 150; pthread_t threads[numThreads]; std::array threadDatas{}; for (auto i = 0; i < numThreads; i++) { pthread_create(&threads[i], nullptr, threadFunc, (void*)&threadDatas[i]); } SECTION( "running threads modify a value" ) { const int saved_b = g_b; resetLoopCountersAndWaitForLoop(threadDatas); CHECK( saved_b != g_b ); } SECTION( "registered threads are suspended during transaction and resumed afterwards" ) { SECTION( "manual thread managing" ) { CHECK( detour_transaction_begin() == err_none ); for (auto&& thread : threads) { CHECK( detour_manage_thread( pthread_mach_thread_np(thread) ) == err_none ); } } SECTION( "manage_all_threads" ) { CHECK( detour_transaction_begin() == err_none ); CHECK( detour_manage_all_threads() == err_none ); } SECTION( "transaction_begin_managed" ) { CHECK( detour_transaction_begin_managed() == err_none ); } int saved_b = g_b; resetLoopCountersAndWaitForLoop(threadDatas, 1s); CHECK( saved_b == g_b ); CHECK( detour_transaction_commit() == err_none ); saved_b = g_b; resetLoopCountersAndWaitForLoop(threadDatas); CHECK( saved_b != g_b ); } SECTION( "aborting a transaction resumes too" ) { CHECK( detour_transaction_begin_managed() == err_none ); CHECK( detour_transaction_abort() == err_none ); const int saved_b = g_b; resetLoopCountersAndWaitForLoop(threadDatas); CHECK( saved_b != g_b ); } SECTION( "thread's PC gets redirected on commit, if it was in an overridden area" ) { // Note that this test case doesn't guarantee that any thread is in the section of code that gets moved into the // trampoline. So if buggy it could be flaky. You can try using more threads to get a more reliable repro. CHECK( detour_transaction_begin_managed() == err_none ); CHECK( detour_attach_and_commit(reinterpret_cast(&realLocalFunction), reinterpret_cast(localFunctionDetour)) == err_none ); const int saved_b = g_b; resetLoopCountersAndWaitForLoop(threadDatas); CHECK( saved_b == g_b ); // clean up CHECK( detour_transaction_begin_managed() == err_none ); CHECK( detour_detach_and_commit(reinterpret_cast(&realLocalFunction), reinterpret_cast(localFunctionDetour)) == err_none ); } for (auto i = 0; i < numThreads; i++) { threadDatas[i].cancel = true; } for (auto i = 0; i < numThreads; i++) { pthread_join(threads[i], nullptr); } }