Access a thread-local variable from another thread

Posted 2024-10-22 04:03:18

How can I read/write a thread local variable from another thread? That is, in Thread A I would like to access the variable in Thread B's thread local storage area. I know the ID of the other thread.

The variable is declared as __thread in GCC. The target platform is Linux, but platform independence would be nice (GCC-specific is okay, however).

Lacking a thread-start hook, I have no way to simply record this value at the start of each thread. All threads need to be tracked this way (not just specially started ones).

A higher level wrapper like boost thread_local_storage or using pthread keys is not an option. I require the performance of using a true __thread local variable.


FIRST ANSWER IS WRONG: One cannot use global variables for what I want to do. Each thread must have its own copy of the variable. Furthermore, those variables must be __thread variables for performance reasons (an equally efficient solution would also be okay, but I know of none). I also don't control the thread entry points, thus there is no possibility for those threads to register any kind of structure.


Thread Local is not private: Another misunderstanding about thread-local variables. These are in no way some kind of private variable for the thread. They are globally addressable memory, with the restriction that their lifetime is tied to the thread. Any function, from any thread, if given a pointer to these variables can modify them. The question above is essentially about how to get that pointer address.
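A small sketch of the point above, assuming GCC or Clang on Linux: the owning thread publishes the address of its own __thread variable, and the main thread then writes through that pointer while the owner is still alive. The function demo_not_private and the promise/future handshake are illustrative choices, not part of the question.

```cpp
#include <future>
#include <thread>

// One copy per thread. GCC's __thread needs a type without dynamic initialization.
static __thread int counter = 0;

// The worker publishes &counter, the caller writes 42 through that pointer,
// and the worker then reads its own copy. Returns what the worker saw.
int demo_not_private() {
    std::promise<int*> addr_promise;  // worker -> caller: address of the worker's counter
    std::promise<void> written;       // caller -> worker: the write has happened
    std::future<int*> addr = addr_promise.get_future();
    std::future<void> written_f = written.get_future();
    int seen = 0;

    std::thread worker([&] {
        addr_promise.set_value(&counter);  // &counter is this thread's instance
        written_f.wait();                  // stay alive so the pointer remains valid
        seen = counter;                    // read our own copy
    });

    int* p = addr.get();  // points into the worker's TLS block
    *p = 42;              // modify the worker's copy from the calling thread
    counter = 7;          // the calling thread's own copy is a separate object
    written.set_value();
    worker.join();
    return seen;
}
```

The promise/future pair only provides the happens-before ordering; the cross-thread access itself is an ordinary pointer dereference.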

Answers (5)

香橙ぽ 2024-10-29 04:03:18

If you want thread local variables that are not thread local, why don't you use global variables instead?

Important clarification!

I am not suggesting that you use a single global to replace a thread-local variable. I'm suggesting using a single global array, or another suitable collection of values, to replace one thread-local variable.

You will have to provide synchronization of course, but since you want to expose a value modified in thread A to thread B there's no getting around that.

Update:

The GCC documentation on __thread says:

"When the address-of operator is applied to a thread-local variable, it is evaluated at run-time and returns the address of the current thread's instance of that variable. An address so obtained may be used by any thread. When a thread terminates, any pointers to thread-local variables in that thread become invalid."

Therefore, if you insist on going this way, I imagine it is possible to take the address of a thread-local variable in the thread it belongs to, just after that thread is spawned. You could then store that pointer in a map (thread id => pointer) and let other threads access the variable through it. This assumes that you own the code of the spawned thread.

If you are really adventurous, you could try digging up information on ___tls_get_addr (start with Ulrich Drepper's "ELF Handling For Thread-Local Storage" PDF, which the GCC documentation above links to). But this approach is so compiler- and platform-specific, and so poorly documented, that it should set off alarm bells in anyone's head.

一百个冬季 2024-10-29 04:03:18

I am searching for the same thing.
Since nobody has answered your question, I searched the web in every way I could and arrived at the following. Assuming GCC on Linux (Ubuntu) compiling with -m64 (x86-64), the fs segment register holds the selector value 0, but the hidden part of the segment (its base linear address) points to the thread's control block.
The first 64-bit word at that address contains the address itself, and the static thread-local variables (those of the executable and of libraries loaded at startup) are stored at lower addresses.
That same address is what native_handle() returns (with glibc, a pthread_t is the address of the thread descriptor).
So to access another thread's thread-local data, you can go through that pointer.

In other words, the other thread's instance lives at: (char*)&variable - (char*)pthread_self() + (char*)theOtherThread.native_handle()
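This formula only works if the distance between a thread-local variable and the thread's pthread_t is identical in every thread. Under glibc that holds for the executable's static TLS block, but it is an undocumented implementation detail. A quick sanity check might look like this, assuming Linux with glibc (tls_offset and offset_is_constant are hypothetical helpers):

```cpp
#include <cstdint>
#include <pthread.h>
#include <thread>

thread_local int probe = 0;  // lives in the executable's static TLS block

// Distance in bytes between this thread's copy of probe and its pthread_t.
// C-style cast because pthread_t is an integer on glibc but a pointer elsewhere.
std::intptr_t tls_offset() {
    return reinterpret_cast<std::intptr_t>(&probe) - (std::intptr_t)pthread_self();
}

// True if a freshly started thread sees the same offset as the caller.
bool offset_is_constant() {
    std::intptr_t other = 0;
    std::thread t([&] { other = tls_offset(); });
    t.join();
    return other == tls_offset();
}
```

If this returns false on some platform, the ot() trick below must not be used there.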

The following code demonstrates this, assuming g++, Linux (x86-64, glibc) and pthreads:

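#include <atomic>
#include <cstdint>
#include <iostream>
#include <pthread.h>
#include <sstream>
#include <thread>

// NOTE: relies on glibc internals: pthread_t is the address of the thread
// descriptor, and the executable's static TLS block sits at the same offset
// from it in every thread. No standard guarantees this.

thread_local int B=0x11111111,A=0x22222222;

// atomic, so that the busy-wait below is not a data race
std::atomic<bool> shouldContinue{false};

void code(){
    while(!shouldContinue.load());   // wait until main has written our A and B
    std::stringstream ss;
    ss<<" A:"<<A<<" B:"<<B<<std::endl;
    std::cout<<ss.str();
}

// Macro form of the same computation:
// #define ot(th,variable) (*(int*)((char*)&(variable)-(char*)pthread_self()+(char*)(th).native_handle()))

// Returns thread th's instance of the thread_local v: the offset of v from the
// calling thread's pthread_t is applied to th's pthread_t.
int& ot(std::thread& th,int& v){
    auto p=pthread_self();
    std::intptr_t d=(std::intptr_t)&v-(std::intptr_t)p;
    return *(int*)((char*)th.native_handle()+d);
}

int main(int argc, char **argv)
{
    std::thread th1(code),th2(code),th3(code),th4(code);

    ot(th1,A)=100;ot(th1,B)=110;
    ot(th2,A)=200;ot(th2,B)=210;
    ot(th3,A)=300;ot(th3,B)=310;
    ot(th4,A)=400;ot(th4,B)=410;

    shouldContinue=true;

    th1.join();
    th2.join();
    th3.join();
    th4.join();

    return 0;
}
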
别念他 2024-10-29 04:03:18

This is an old question, but since no answer has been given, why not use a class that registers itself in a static map when it is constructed?

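#include <mutex>
#include <thread>
#include <unordered_map>

struct foo;

// thread id -> that thread's foo; guarded by foos_mutex
static std::unordered_map<std::thread::id, foo*> foos;
static std::mutex foos_mutex;

struct foo
{
    int value = 0; // placeholder (illustrative) for the per-thread data

    foo()
    {
        std::lock_guard<std::mutex> lk(foos_mutex);
        foos[std::this_thread::get_id()] = this;
    }

    ~foo()
    {
        // unregister when the thread exits, so no dangling pointer stays in the map
        std::lock_guard<std::mutex> lk(foos_mutex);
        foos.erase(std::this_thread::get_id());
    }
};

static thread_local foo tls_foo;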

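Of course, you need some synchronization between the threads to make sure the target thread has already registered its pointer (with GCC, the thread_local object is typically constructed only when that thread first uses it). After that, you can look the pointer up in the map from any thread that knows the target thread's id.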

戏蝶舞 2024-10-29 04:03:18

I was unfortunately never able to find a way to do this.

Without some kind of thread init hook there just doesn't appear to be a way to get at that pointer (short of ASM hacks that would be platform dependent).

月亮是我掰弯的 2024-10-29 04:03:18

This does pretty much what you need, and if not, it can be modified to fit your requirements.

On Linux it uses pthread_key_create, and on Windows it uses TlsAlloc; both retrieve a thread-local value by key. However, because every thread's object is also registered in a shared list, you can access the data of other threads as well.

The idea of EnumerableThreadLocal is that you perform a local operation in your threads and then reduce the results back down in your main thread.
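The "local work, then reduce" idea can be illustrated without the class. In this sketch (parallel_sum is a hypothetical helper), each worker registers a private accumulator in a mutex-protected list and updates it without locking. The calling thread sums the accumulators after join(). Heap-allocating each accumulator keeps its address stable while the vector grows, which is the same reason EnumerableThreadLocal stores std::unique_ptr<T>.

```cpp
#include <cstddef>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

// Sums data with `workers` threads: each thread accumulates into its own partial
// result, and the partials are reduced on the calling thread at the end.
long long parallel_sum(const std::vector<int>& data, unsigned workers) {
    std::mutex mtx;
    std::vector<std::unique_ptr<long long>> partials;  // one entry per worker

    std::vector<std::thread> threads;
    for (unsigned w = 0; w < workers; ++w) {
        threads.emplace_back([&, w] {
            long long* mine;
            {
                std::lock_guard<std::mutex> lk(mtx);  // register this thread's partial
                partials.push_back(std::make_unique<long long>(0));
                mine = partials.back().get();
            }
            for (std::size_t i = w; i < data.size(); i += workers)
                *mine += data[i];  // purely local work, no lock needed
        });
    }
    for (auto& t : threads) t.join();  // join() makes every partial visible here

    long long total = 0;  // the reduce step
    for (const auto& p : partials) total += *p;
    return total;
}
```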

TBB has a similar class called enumerable_thread_specific; the motivation for it can be found at https://oneapi-src.github.io/oneTBB/main/tbb_userguide/design_patterns/Divide_and_Conquer.html

The code below is an attempt to mimic the TBB class without depending on TBB. Its downside is that on Windows you are limited to 1088 keys (TLS indexes) per process.

#include <functional>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>
#if _WIN32 || _WIN64
#include <windows.h>
#else
#include <pthread.h>
#endif

namespace v31
{
    template <typename T>
    class EnumerableThreadLocal
    {

#if _WIN32 || _WIN64
        using tls_key_t = DWORD;
        void create_key() { my_key = TlsAlloc(); m_has_key = (my_key != TLS_OUT_OF_INDEXES); }
        void destroy_key() { if (m_has_key) TlsFree(my_key); m_has_key = false; }
        void set_tls(void *value) { if (m_has_key) TlsSetValue(my_key, (LPVOID)value); }
        void *get_tls() { return m_has_key ? (void *)TlsGetValue(my_key) : nullptr; }
#else
        using tls_key_t = pthread_key_t;
        void create_key() { m_has_key = (pthread_key_create(&my_key, nullptr) == 0); }
        void destroy_key() { if (m_has_key) pthread_key_delete(my_key); m_has_key = false; }
        void set_tls(void *value) { if (m_has_key) pthread_setspecific(my_key, value); }
        void *get_tls() { return m_has_key ? pthread_getspecific(my_key) : nullptr; }
#endif
        // (thread id, that thread's object); m_mtx guards every access
        std::vector<std::pair<std::thread::id, std::unique_ptr<T>>> m_thread_locals;
        std::mutex m_mtx;
        tls_key_t my_key{};
        // false if key creation failed or the key was handed over by a move,
        // so a key is never freed twice (a key value of 0 is a valid key)
        bool m_has_key = false;

        using Factory = std::function<std::unique_ptr<T>()>;
        Factory m_factory;

        static std::unique_ptr<T> DefaultFactory()
        {
            return std::make_unique<T>();
        }

    public:

        EnumerableThreadLocal(Factory factory = &DefaultFactory) : m_factory(std::move(factory))
        {
            create_key();
        }

        ~EnumerableThreadLocal()
        {
            destroy_key();
        }

        EnumerableThreadLocal(const EnumerableThreadLocal &other) : m_factory(other.m_factory)
        {
            create_key();
            // deep copy the per-thread objects; the new key starts out empty, so each
            // thread finds its copy through the slow lookup in Get()
            m_thread_locals.reserve(other.m_thread_locals.size());
            for (const auto &pair : other.m_thread_locals)
            {
                m_thread_locals.emplace_back(pair.first, std::make_unique<T>(*pair.second));
            }
        }

        EnumerableThreadLocal &operator=(const EnumerableThreadLocal &other)
        {
            if (this != &other)
            {
                // a fresh key, so no thread keeps a cached pointer to an object freed below
                destroy_key();
                create_key();
                m_factory = other.m_factory;
                m_thread_locals.clear();
                m_thread_locals.reserve(other.m_thread_locals.size());
                for (const auto &pair : other.m_thread_locals)
                {
                    m_thread_locals.emplace_back(pair.first, std::make_unique<T>(*pair.second));
                }
            }
            return *this;
        }

        EnumerableThreadLocal(EnumerableThreadLocal &&other) noexcept
            : m_thread_locals(std::move(other.m_thread_locals)), // leaves other's vector empty
              my_key(other.my_key),
              m_has_key(other.m_has_key),
              m_factory(std::move(other.m_factory))
        {
            // the key now belongs to *this; other must not free it
            other.m_has_key = false;
        }

        EnumerableThreadLocal &operator=(EnumerableThreadLocal &&other) noexcept
        {
            if (this != &other)
            {
                destroy_key();
                my_key = other.my_key;
                m_has_key = other.m_has_key;
                other.m_has_key = false;
                m_thread_locals = std::move(other.m_thread_locals);
                other.m_thread_locals.clear(); // a moved-from vector is otherwise unspecified
                m_factory = std::move(other.m_factory);
            }
            return *this;
        }

        T *Get()
        {
            void *v = get_tls();
            if (v)
            {
                return reinterpret_cast<T *>(v);
            }
            else
            {
                const std::scoped_lock l(m_mtx);
                for (const auto &[thread_id, uptr] : m_thread_locals)
                {
                    // Needed when no TLS key could be allocated (or after a copy, whose
                    // key starts out empty): a slow but correct lookup by thread id
                    if (thread_id == std::this_thread::get_id())
                    {
                        set_tls(reinterpret_cast<void *>(uptr.get()));
                        return uptr.get();
                    }
                }

                m_thread_locals.emplace_back(std::this_thread::get_id(), m_factory());
                T *ptr = m_thread_locals.back().second.get();
                set_tls(reinterpret_cast<void *>(ptr));
                return ptr;
            }
        }

        T const * Get() const
        {
            return const_cast<EnumerableThreadLocal *>(this)->Get();
        }

        T & operator *()
        {
            return *Get();
        }

        T const & operator *() const
        {
            return *Get();
        }

        T * operator ->()
        {
            return Get();
        }

        T const * operator ->() const
        {
            return Get();
        }

        // Calls fn on every thread's object, under the lock (e.g. for the reduce step)
        template <typename F>
        void Enumerate(F fn)
        {
            const std::scoped_lock lock(m_mtx);
            for (auto &[thread_id, ptr] : m_thread_locals)
                fn(*ptr);
        }
    };
} // namespace v31

And here is a suite of test cases that shows how it works:

#include <algorithm>
#include <string>
#include <thread>
#include <vector>
#include "gtest/gtest.h"
#include "EnumerableThreadLocal.hpp"

TEST(EnumerableThreadLocal, BasicTest)
{
    const int N = 10;
    v31::EnumerableThreadLocal<std::string> tls;

    // Create N threads; thread i stores "Thread i" in its own string
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { *tls = "Thread " + std::to_string(i); });
    }

    // Wait for all threads to finish
    for (auto &thread : threads)
        thread.join();

    std::vector<std::string> expected;
    tls.Enumerate([&](std::string &s)
                  { expected.push_back(s); });

    // Sort the expected vector
    std::sort(expected.begin(), expected.end());

    // check the expected vector
    for (int i = 0; i < N; ++i)
    {
        ASSERT_EQ(expected[i], "Thread " + std::to_string(i));
    }


}

// Create a non copyable type, non moveable type
struct NonCopyable
{
    int i=0;
    NonCopyable() = default;
    NonCopyable(const NonCopyable &) = delete;
    NonCopyable(NonCopyable &&) = delete;
    NonCopyable &operator=(const NonCopyable &) = delete;
    NonCopyable &operator=(NonCopyable &&) = delete;
};

// A test to see if we can insert non moveable/ non copyable types to the tls
TEST(EnumerableThreadLocal, NonCopyableTest)
{
    const int N = 10;
    v31::EnumerableThreadLocal<NonCopyable> tls;

    // Create N threads; thread i stores i in its own NonCopyable instance
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { tls->i=i; });
    }

    // Wait for all threads to finish
    for (auto &thread : threads)
        thread.join();

    std::vector<int> expected;
    tls.Enumerate([&](NonCopyable &s)
                  { expected.push_back(s.i); });

    // Sort the expected vector
    std::sort(expected.begin(), expected.end());

    // check the expected vector
    for (int i = 0; i < N; ++i)
    {
        ASSERT_EQ(expected[i], i);
    }
}

const int N = 10;
v31::EnumerableThreadLocal<std::string> CreateFixture()
{
    v31::EnumerableThreadLocal<std::string> tls;

    // Create N threads; thread i stores "Thread i" in its own string
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { *tls = "Thread " + std::to_string(i); });
    }

    // Wait for all threads to finish
    for (auto &thread : threads)
        thread.join();

    return tls;
}

void CheckFixtureCopy(v31::EnumerableThreadLocal<std::string> & tls)
{
    std::vector<std::string> expected;

    tls.Enumerate([&](std::string &s)
                    { expected.push_back(s); });

    // Sort the expected vector
    std::sort(expected.begin(), expected.end());

    // check the expected vector
    for (int i = 0; i < N; ++i)
    {
        ASSERT_EQ(expected[i], "Thread " + std::to_string(i));
    }
}

void CheckFixtureEmpty(v31::EnumerableThreadLocal<std::string> & tls)
{
    std::vector<std::string> expected;

    tls.Enumerate([&](std::string &s)
                    { expected.push_back(s); });

    ASSERT_EQ(expected.size(), 0u);
}

/// Test for copy construct of EnumerableThreadLocal
TEST(EnumerableThreadLocal, Copy)
{
    auto tls = CreateFixture();
    // Copy the tls
    auto tls_copy = tls;

    CheckFixtureCopy(tls_copy);
    CheckFixtureCopy(tls);
}

/// Test for move construct of EnumerableThreadLocal
TEST(EnumerableThreadLocal, Move)
{
    auto tls = CreateFixture();
    // Move the tls
    auto tls_copy = std::move(tls);

    CheckFixtureCopy(tls_copy);
    CheckFixtureEmpty(tls);
}

/// Test for copy assign of EnumerableThreadLocal
TEST(EnumerableThreadLocal, CopyAssign)
{
    auto tls = CreateFixture();
    // Copy the tls
    v31::EnumerableThreadLocal<std::string> tls_copy;
    CheckFixtureEmpty(tls_copy);
    tls_copy = tls;

    CheckFixtureCopy(tls_copy);
    CheckFixtureCopy(tls);
}   

/// Test for move assign of EnumerableThreadLocal
TEST(EnumerableThreadLocal, MoveAssign)
{
    auto tls = CreateFixture();
    // Move-assign the tls into an empty instance
    v31::EnumerableThreadLocal<std::string> tls_copy;
    CheckFixtureEmpty(tls_copy);
    tls_copy = std::move(tls);

    CheckFixtureCopy(tls_copy);
    CheckFixtureEmpty(tls);
}

//class with no default constructor
struct NoDefaultConstructor
{
    int i;
    NoDefaultConstructor(int i) : i(i) {}
};

// Test for using objects with no default constructor
TEST(EnumerableThreadLocal, NoDefaultConstructor)
{
    const int N = 10;
    v31::EnumerableThreadLocal<NoDefaultConstructor> tls([]{return std::make_unique<NoDefaultConstructor>(0);});

    // Create N threads; thread i stores i in its own NoDefaultConstructor instance
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { tls->i = i; });
    }

    // Wait for all threads to finish
    for (auto &thread : threads)
        thread.join();

    // enumerate and sort and verify
    std::vector<int> expected;  
    tls.Enumerate([&](NoDefaultConstructor &s)
                    { expected.push_back(s.i); });

    // Sort the expected vector
    std::sort(expected.begin(), expected.end());

    // check the expected vector
    for (int i = 0; i < N; ++i)
    {
        ASSERT_EQ(expected[i], i);
    }

}