130行C语言实现个用户态线程库(1)

准确的说是除掉头文件,测试代码和非关键的纯算法代码(只有双向环形链表的ADT),核心代码只有130行左右,已是蝇量级的用户态线程库了。把这个库取名为ezthread,意思是,这太easy了,人人均可以读懂而且实现这个用户态线程库。我把该项目放在github上,欢迎来拍砖: https://github.com/Yuandong-Chen/coroutine/tree/old-version(注意,最新的版本已经用了共享栈技术,可以支持1000K数量级的协程了,读完这篇博文后能够进一步参考后续的博文:http://www.cnblogs.com/github-Yuandong-Chen/p/6973932.html)。那么下面谈谈怎么实现这个ezthread。html

你们都会双向环形链表(就是头尾相连的双向链表),咱们构造这个ADT结构:node

首先是每一个节点:git

1 typedef struct __pnode pNode;
2 struct __pnode
3 {
4     pNode *next;
5     pNode *prev;
6     Thread_t *data;
7 };

显然,next指向下一个节点,prev指向上一个节点,data指向该节点数据,那么这个Thread_t是什么类型的数据结构呢?github

typedef struct __ez_thread Thread_t;
struct __ez_thread
{
    Regs regs;
    int tid;
    unsigned int stacktop;
    unsigned int stacksize;
    void *stack;
    void *retval;
};

这个结构体包含了线程内部的信息,好比第一项为Regs,记录的是各个寄存器的取值(咱们在下面给出具体的结构),tid就是线程的ID了,stacktop记录的是线程栈的顶部(和页对齐的最大地址,每一个线程都有本身的运行时的栈,用于构成他们相对独立的运行时环境),stacksize就是栈的大小了,stack指针指向咱们给该线程栈分配的堆的指针(什么?怎么一会栈一会堆的?咱们其实用了malloc函数分配出一些堆空间,把这些空间用于线程栈,当线程退出时候,咱们再free这些堆),retval就是线程运行完了的返回值(pthread_join里头拿到的线程返回值就是这个了)。算法

下面是寄存器结构体:数据结构

typedef struct __thread_table_regs Regs;
struct __thread_table_regs
{
    int _ebp;
    int _esp;
    int _eip;
    int _eflags;
};

真是好懂,一看就知道了,这个结构体只能支持X86体系的计算机了。那么还有个问题,为什么只有这些寄存器,没用其余的好比:eax,ebx,edi,esi等等呢?由于咱们在转换状态函数switch_to里头当返回时(准确地说是从上次切换的点切换回来时)用了return来切换回线程运行时环境,return会自动帮助咱们把这些其余的寄存器的值恢复原状的(具体咱们放到switch_to的时候再详细说明)。多线程

而后呢,咱们定义了一个游标去取这个环形链表的值,不然咱们怎么读取这个环形链表里头的数据呢?总得有个东西指向其中某个节点吧。app

typedef struct __loopcursor Cursor;
struct __loopcursor
{
    int total;
    pNode *current;
};

这个游标结构体记录了如今指向的节点地址和这个环形链表里头一共有多少节点。ide

咱们得用两个这样的环形链表结构体来支持咱们的线程库,为什么是俩呢?一个是正在运行的线程,咱们把他们串成一个环形链表,取名为live(活的),而后用另一个链表把运行结束的线程串成一串,取名为dead(死的)。而后最开始咱们就有个线程在运行了,那就是主线程main,咱们用pmain节点来记录主线程:函数

extern Cursor live;
extern Cursor dead;
extern Thread_t pmain;

好了,剩下的只有在这些结构体上操做的函数了:

void init();
void switch_to(int, ...);
int threadCreat(Thread_t **, void *(*)(void *), void *);
int threadJoin(Thread_t *, void **);

咱们开始时调用init,以初始化咱们的live,dead和pmain。而后当咱们想创造线程时,就threadCreat就能够了,用法和pthread_create基本如出一辙,熟悉posix多线程的人一看就明白了,threadJoin也是仿照pthread_join接口写的。这里的switch_to就是最关键的运行时环境转换函数了,当线程调用这个函数时候,咱们就切换到其余线程上次暂停的点去执行了(这些状态都保存在咱们的Thread_t结构体里,因此咱们可以记录下切换前的状态,从而可以从容地去切换到下一个线程中)。咱们没有用定时器每隔几微秒去激发switch_to(实现起来也是很是简单的,可是得添加多个signal_block函数,很是不简洁),而是让线程里头的函数主动调用switch_to来切换线程,这有点相似协程。

好了,如今讲具体的实现了。首先是对双向链表的操做函数,这个东西不是咱们的重点,懂基础算法数据结构的人都能实现,具体是双向环形链表的增查删操做:

 1 void initCursor(Cursor *cur)
 2 {
 3     cur->total = 0;
 4     cur->current = NULL;
 5 }
 6 
 7 Thread_t *findThread(Cursor *cur, int tid)
 8 {
 9     int counter = cur->total;
10     if(counter == 0){
11         return NULL;
12     }
13 
14     int i;
15     pNode *tmp = cur->current;
16     for (int i = 0; i < counter; ++i)
17     {
18         if((tmp->data)->tid == tid){
19             return tmp->data;
20         }
21 
22         tmp = tmp->next;
23     }
24     return NULL;
25 }
26 
27 int appendThread(Cursor *cur, Thread_t *pth)
28 {
29     if(cur->total == 0)
30     {
31         cur->current = (pNode *)malloc(sizeof(pNode));
32         assert(cur->current);
33         (cur->current)->data = pth;
34         (cur->current)->prev = cur->current;
35         (cur->current)->next = cur->current;
36         cur->total++;
37         return 0;
38     }
39     else
40     {    
41         if(cur->total > MAXCOROUTINES)
42         {
43             assert((cur->total == MAXCOROUTINES));
44             return -1;
45         }
46         
47         pNode *tmp = malloc(sizeof(pNode));
48         assert(tmp);
49         tmp->data = pth;
50         tmp->prev = cur->current;
51         tmp->next = (cur->current)->next;
52         ((cur->current)->next)->prev = tmp;
53         (cur->current)->next = tmp;
54         cur->total++;
55         return 0;
56     }
57 }
58 
59 pNode *deleteThread(Cursor *cur, int tid)
60 {
61     int counter = cur->total;
62     int i;
63     pNode *tmp = cur->current;
64     for (int i = 0; i < counter; ++i)
65     {
66         if((tmp->data)->tid == tid){
67             (tmp->prev)->next = tmp->next;
68             (tmp->next)->prev = tmp->prev;
69             if(tmp == cur->current)
70             {
71                 cur->current = cur->current->next;
72             }  
73 
74             cur->total--;
75             assert(cur->total >= 0);
76             return tmp;
77         }
78         tmp = tmp->next;
79     }
80     return NULL;
81 }
双向链表操做函数

抛开这部分纯算法代码,咱们只剩下130行代码了。这还不如某些函数的代码量大。可是咱们就是在这130行代码里头实现了switch_to,threadCreat以及threadJoin等等关键代码。

先说下init怎么实现的:

1 void init()
2 {
3     initCursor(&live);
4     initCursor(&dead);
5     appendThread(&live, &pmain);
6 }

其实关键点只有一句,那就是第5行的append(&live,&pmain);往live链表里头添加pmain节点,可是咱们的pmain还没初始化呢,里头stack,regs等等统统都是0,可是没事呢,由于当咱们第一次进入switch_to的时候,switch_to在跳转前会帮助咱们保存当前线程,这时也就是pmain的运行时状态。

而后咱们看看threadCreat怎么实现:

 1 int threadCreat(Thread_t **pth, void *(*start_rtn)(void *), void *arg)
 2 {
 3 
 4     *pth = malloc(sizeof(Thread_t));
 5     (*pth)->stack = malloc(PTHREAD_STACK_MIN);
 6     assert((*pth)->stack);
 7     (*pth)->stacktop = (((int)(*pth)->stack + PTHREAD_STACK_MIN)&(0xfffff000));
 8     (*pth)->stacksize = PTHREAD_STACK_MIN - (((int)(*pth)->stack + PTHREAD_STACK_MIN) - (*pth)->stacktop);
 9     (*pth)->tid = fetchTID();
10     /* set params */
11     void *dest = (*pth)->stacktop - 12;
12     memcpy(dest, pth, 4);
13     dest += 4;
14     memcpy(dest, &start_rtn, 4);
15     dest += 4;
16     memcpy(dest, &arg, 4);
17     (*pth)->regs._eip = &real_entry;
18     (*pth)->regs._esp = (*pth)->stacktop - 16;
19     (*pth)->regs._ebp = 0;
20     appendThread(&live, (*pth));
21 
22     return 0;
23 }

咱们在第4行分配了堆空间,而后让线程栈顶变量stacktop对齐页,设置stacksize大小(这个其实对咱们的线程库没有用,由于咱们尚未实现相似stackguard之类的检查机制),设置tid,这里fetchTID函数以下:

1 int fetchTID()
2 {
3     static int tid;
4     return ++tid;
5 }

接着,咱们在threadCreat函数的11-16行代码中,在栈顶压入变量pth,start_rtn以及arg(咱们用memcpy来操做线程栈空间),这些都是做为real_entry这个函数的参数压入线程栈的。咱们不难发现,其实每一个线程的最初入口地址都是real_entry函数(注意到咱们在17行把eip设置为real_entry的地址)。最后,咱们于17-19行设置寄存器变量,以知足刚进入该real_entry时的栈的状态,在live链表中添加该线程结构体指针,返回。这一系列操做致使的效果就是,好比咱们第一次调用threadCreat函数,当发生switch_to的时候,固然咱们先保存当前线程状态,而后就从主线程main中切换到了real_entry里头去了,并且对应的参数咱们设置好了,就好像咱们在主线程里头直接调用了real_entry同样。下面看下real_entry作了些什么:

 1 void real_entry(Thread_t *pth, void *(*start_rtn)(void *), void* args)
 2 {
 3     ALIGN();
 4 
 5     pth->retval = (*start_rtn)(args);
 6 
 7     deleteThread(&live, pth->tid);
 8     appendThread(&dead, pth);
 9 
10     switch_to(-1);
11 }

 

 第3行是对齐栈操做,咱们先不作说明。接下来就是调用start_rtn函数,而且把args做为参数,返回值赋给线程的retval。当返回时,说明线程已经运行结束,在live链表里头删除该节点,在dead链表里头添加该节点。在第10行最后调用switch_to(-1),也就是在switch_to里头直接跳到下一个线程去执行,且不保存当前状态。

咱们再看下threadJoin函数的实现:

 1 int threadJoin(Thread_t *pth, void **rval_ptr)
 2 {
 3 
 4     Thread_t *find1, *find2;
 5     find1 = findThread(&live, pth->tid);
 6     find2 = findThread(&dead, pth->tid);
 7     
 8 
 9     if((find1 == NULL)&&(find2 == NULL)){
10         
11         return -1;
12     }
13 
14     if(find2){
15         if(rval_ptr != NULL)
16             *rval_ptr = find2->retval;
17 
18         pNode *tmp = deleteThread(&dead, pth->tid);
19         free(tmp);
20         free((Stack_t)find2->stack);
21         free(find2);
22         return 0;
23     }
24 
25     while(1)
26     {
27         switch_to(0);
28         if((find2 = findThread(&dead, pth->tid))!= NULL){
29             if(rval_ptr!= NULL)
30                 *rval_ptr = find2->retval;
31 
32             pNode *tmp = deleteThread(&dead, pth->tid);
33             free(tmp);
34             free((Stack_t)find2->stack);
35             free(find2);
36             return 0;
37         }   
38     }
39     return -1;
40 }

threadJoin是用于回收线程资源并获得返回值的。实现大致的思路就是,咱们先查找live和dead里头有没有这个线程,若是都没有,说明根本不存在这个线程,若是dead链表里头有,那么咱们就获得返回值(15-16行),而后释放堆空间(19-22行)。若是在live里头,说明该线程还没执行结束,咱们进入循环,先调用switch_to(0),保存当前线程状态,而后切换到下一个线程去。当再次回到这个循环时候,咱们继续看看dead里头有没有这个线程,有就设置返回值(29-30行),而后释放资源(32-35行),不然继续切换并循环。

最后,最关键的,咱们给出switch_to的实现:

 1 void switch_to(int signo, ...)
 2 {
 3 
 4     va_list ap; 
 5     va_start(ap, signo);
 6 
 7     Regs regs;
 8 
 9     if(signo == -1)
10     {
11         regs = live.current->data->regs;
12         JMP(regs);
13         assert(0);
14     }
15     
16     int _ebp;
17     int _esp;
18     int _eip = &&_REENTERPOINT;
19     int _eflags;
20     /* save current context */
21     SAVE();
22     /* save context in current thread */
23     live.current->data->regs._eip = _eip;
24     live.current->data->regs._esp = _esp;
25     live.current->data->regs._ebp = _ebp;
26     live.current->data->regs._eflags = _eflags;
27 
28     if(va_arg(ap,int) == -1){
29  _REENTERPOINT:
30         assert(va_arg(ap,int) != -1);
31         return;
32     }
33 
34     va_end(ap);
35     regs = live.current->next->data->regs;
36     live.current = live.current->next;
37     JMP(regs);
38     assert(0);
39 }

先看11-13行,咱们把自动变量regs的值赋为当前线程的寄存器的结构体,而后跳转到当前线程(第12行JMP是跳转语句,13行永远不会执行)。这里你们有个疑问,从当前线程跳转到当前线程,那么还不是当前线程么?而后执行assert(0)报错退出?!其实只有当线程返回时,也就是在real_entry里头才可能执行switch_to(-1),注意到real_entry最后的几行代码,里头已经把当前线程从live里头删除,并添加到dead里了,因此如今live里头的当前线程实际上是下一个线程。而后咱们看21-26行,咱们保存当前寄存器的值到当前线程中,注意第18行,咱们把返回点设置在了_REENTERPOINT这个标签上,也就是之后若是再次切换到该线程时,咱们会在第30行继续向下执行,很简单,第30行的有意义的代码只有return,也就是恢复其余寄存器(eax,edi,esi等等),而后返回到线程继续执行。咱们继续看34-38行代码:咱们把自动变量regs的值赋值为下一个线程的寄存器,而后live的当前线程指针current也指向了下一个线程,经过37行JMP,咱们调到了下一个线程去执行,下个一个线程多是real_entry处开始执行,也多是_REENTERPOINT处开始执行。最后再重新说说31行的return到底return到哪里去了,咱们看一下测试代码:

 1 #include "ezthread.h"
 2 #include <stdio.h>
 3 #include <stdlib.h>
 4 
 5 void *sum1tod(void *d)
 6 {
 7     int i, j=0;
 8 
 9     for (i = 0; i <= d; ++i)
10     {
11         j += i;
12         printf("thread %d is grunting... %d\n",live.current->data->tid , i);
13         switch_to(0); // Give up control to next thread
14     }
15     
16     return ((void *)j);
17 }
18 
19 int main(int argc, char const *argv[])
20 {
21     int res = 0;
22     int i;
23     init();
24     Thread_t *tid1, *tid2;
25     int *res1, *res2;
26 
27     threadCreat(&tid1, sum1tod, 10);
28     threadCreat(&tid2, sum1tod, 10);
29 
30     for (i = 0; i <= 5; ++i){
31         res+=i;
32         printf("main is grunting... %d\n", i); 
33         switch_to(0); //Give up control to next thread
34     }
35     threadJoin(tid1, &res1); //Collect and Release the resourse of tid1
36     threadJoin(tid2, &res2); //Collect and Release the resourse of tid2
37     printf("parallel compute: %d = (1+2+3+4+5) + (1+2+...+10)*2\n", (int)res1+(int)res2+(int)res);
38     return 0;
39 }

注意到咱们在测试代码里头sum1tod里头调用了switch_to(0),若是这个循环加法(11-13行)还未结束,那么上述的那个_REENTERPOINT里头的return就会return回这个循环继续执行,就如在sum1tod里的switch_to(0)函数直接调用return,什么事情也没干同样,可是其实咱们通过了无数其余线程的执行,可是在sum1tod里头毫无感受,简直好像其余线程不存在同样(除非咱们在这里头调用threadJoin等待其余线程结束)。

如今咱们给出讨厌的内嵌汇编:

 1 #define JMP(r)    asm volatile \
 2                 (   \
 3                     "pushl %3\n\t" \
 4                     "popf\n\t" \
 5                     "movl %0, %%esp\n\t" \
 6                     "movl %2, %%ebp\n\t" \
 7                     "jmp *%1\n\t" \
 8                     : \
 9                     : "m"(r._esp),"a"(r._eip),"m"(r._ebp), "m"(r._eflags) \
10                     :  \
11                 )
12 
13 #define SAVE()                  asm volatile \
14                             (  \
15                                    "movl %%esp, %0\n\t" \
16                                 "movl %%ebp, %1\n\t" \
17                                 "pushf\n\t" \
18                                 "movl (%%esp), %%eax\n\t" \
19                                 "movl %%eax, %2\n\t" \
20                                 "popf\n\t" \
21                                 : "=m"(_esp),"=m"(_ebp), "=m"(_eflags) \ 
22                                 : \
23                                 :  \
24                             )
25 
26 #define ALIGN()             asm volatile \
27                             ( \
28                                 "andl $-16, %%esp\n\t" \
29                                 : \
30                                 : \
31                                 :"%esp" \
32                             )
inline asm

第一个就是起到跳转做用,第二个是保存寄存器到自动变量做用,最后一个是栈对齐做用。为什么要栈对齐?由于咱们在堆里头设置了这个栈的空间,这个和普通的栈空间并不彻底同样,咱们须要作对齐处理。

到这里咱们就几乎彻底明白了这个线程库的实现,还有一小点就是switch_to里头的可变参数怎么回事,其实那个是防止编译器中消除冗余代码形成咱们_REENTERPOINT中的代码被优化而整个删除用的。若是咱们在_REENTERPOINT前加入goto语句跳到下面执行,而后删除这个_REENTERPOINT以前的判断语句,咱们会发现,编译器会把switch_to里头的第28-32行做为冗余代码所有删除。

谢谢你能看到最后,告诉大家一个消息,其实咱们的实现是介于longjmp和汇编实现版本之间的某种实现:咱们用汇编保存了运行时状态,可是其中的return又有点相似longjmp中自动恢复寄存器的做用。并且咱们的库比纯汇编实现更具可移植性,但比longjmp实现版本又弱了点。

相关文章
相关标签/搜索