在遇到线程安全问题的时候,咱们通常都是使用同步来解决,好比内置锁、显示锁等等。线程安全的主要原由是由于多个线程同时操做一个共享变量,若是咱们换种思路,在某些场景下,咱们为这些线程提供共享变量的副本,让他们在本身的私有域中去操做这些变量,线程之间互不影响,那是否是就不会产生线程安全问题了?ThreadLocal提供了这样的一种实现。html
ThreadLocal内部封装了ThreadLocalMap结构来为线程提供存储数据的私有域空间,而Thread类提供了成员变量threadLocals来ThreadLocalMap,这样ThreadLocal、TreadLocalMap、Thread就紧密联系起来了。ThreadLocal对外提供了get、set、remove等方法来供咱们操做Thread的私有域空间ThreadLocalMap。这里咱们先说个大概,后面分析源码的时候再来一一解释。java
接下来直接看ThreadLocal的源码。数组
ThreadLocal的类结构安全
ThreadLocal的类是java.lang包下的一个普通类,没有任何类的继承与接口实现。数据结构
public class ThreadLocal<T> { ...... }
ThreadLocal的成员变量this
private final int threadLocalHashCode = nextHashCode(); private static AtomicInteger nextHashCode = new AtomicInteger(); private static final int HASH_INCREMENT = 0x61c88647;
ThreadLocal的构造方法线程
public ThreadLocal() {}
ThreadLocal的内部类
ThreadLocalMap:code
咱们第一眼就看到ThreadLocalMap中又有一个内部类Entry,好,咱们一个一个看。htm
Entry:对象
Entry就是ThreadLocalMap中实际存放数据的单个节点,为了便于理解,咱们能够参照HashMap中的Node节点。Entry组成的数组就是ThreadLocalMap的底层封装数据的数据结构。
Entry继承于WeakReference(弱引用),对于弱引用,咱们先作个大概的了解。
若是一个对象仅被WeakReference指向,而没有其余任何强引用指向的话,在下一次GC的时候,弱引用指向的对象就会被回收。
//ThreadLocalMap的map中定义内部类Entry,Entry就是具体存储数据的结构 //Entry继承了弱引用 //Entry的key是啥?是ThreadLocal的弱引用 static class Entry extends WeakReference<ThreadLocal<?>> { //存放的数据 Object value; //Entry的构造方法 Entry(ThreadLocal<?> k, Object v) { super(k); value = v; } }
Entry中有两个成员变量,一个是Ojbect类型的value,还有一个是继承于WeakReference的类型为ThreadLocal的reference。咱们能够把reference看作是key。
接着继续看ThreadLocalMap中的成员变量和构造方法。
static class ThreadLocalMap { //节点数组的初始化容量值 private static final int INITIAL_CAPACITY = 16; //Entry节点数组,存放数据的数组 private Entry[] table; //Entry数组中实际存储数据的数目,初始为0 private int size = 0; //Entry数组扩容的阈值 private int threshold; //ThreadLocalMap的构造方法 ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) { //初始化Entry数组,容量为默认的初始值16 table = new Entry[INITIAL_CAPACITY]; //threadLocalHashCode = nextHashCode(), //INITIAL_CAPACITY为16,因此(INITIAL_CAPACITY - 1)的二进制形式为1111, //与(INITIAL_CAPACITY - 1)进行位与运算就是至关于threadLocalHashCode对16取模 //这是由于Entry数组是一个长度为16的数组圆环,而key的落脚点便是在这个HashCode对16取模的值 //i就是当前这个key在Entry环形数组的索引值 int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1); //将ThreadLocal和value值构建成一个Entry,放置在ENtry数组中, table[i] = new Entry(firstKey, firstValue); //由于是构造方法,这里确定是第一次存入数据,因此size为1 size = 1; //设置entry数组的阈值,阈值为当前Entry数组长度的三分之二 setThreshold(INITIAL_CAPACITY); } //这个方法是ThreadLocal的方法 private static int nextHashCode() { //nextHashCode为AtomicInteger类型 //AtomicInteger的getAndAdd()方法就是以用Unsafe的设置方式去更新这个AtomicInteger //更新为当前值+HASH_INCREMENT return nextHashCode.getAndAdd(HASH_INCREMENT); } //这个方法是AtomicInteger的方法 public final int getAndAddInt(Object var1, long var2, int var4) { int var5; do { //var5即为当前这个AtomicInteger的值 var5 = this.getIntVolatile(var1, var2); } while(!this.compareAndSwapInt(var1, var2, var5, var5 + var4)); //将AtomicInteger的当前值var5更新为var5+var4,而war4即为增量 return var5; } //这个方法是Entry自己的方法 private void setThreshold(int len) { //阈值为当前entry数组长度的三分之二 threshold = len * 2 / 3; } //ThreadLocalMap的构造方法,参数为一个ThreadLocalMap private ThreadLocalMap(ThreadLocalMap parentMap) { //获取参数ThreadLocalMap中的Entry数组 Entry[] parentTable = parentMap.table; //获取参数Entry数组的长度 int len = parentTable.length; //设置阈值为数组长度的三分之二 setThreshold(len); //建立一个新的数组,将数组赋值给当前Entry数组table table = new Entry[len]; //循环遍历 for (int j = 0; j < len; j++) { //获取参数entry数组的每一个entry节点 Entry e = parentTable[j]; if (e != null) { @SuppressWarnings("unchecked") //e.get()返回引用referent,这个referent即为ThreadLocal ThreadLocal<Object> key = (ThreadLocal<Object>) e.get(); if (key != null) { //获取value Object value = key.childValue(e.value); //对key和value作完基本校验后,组建新的Entry节点 Entry c = new Entry(key, value); //计算下角标位置 int h = key.threadLocalHashCode & (len - 1); while (table[h] != null) //若是该下角标位置已经有元素了,计算下个索引位置 h = nextIndex(h, len); //直到计算出的索引位置上没有元素时,将新建的entry放到该索引位置 table[h] = c; //entry数组的元素数量加一 size++; } } } } //当前下角标i的下一个索引位置,若是达到entry数组的长度16的话,从新从0开始 private static int nextIndex(int i, int len) { return ((i + 1 < len) ? i + 1 : 0); } }
ThreadLocalMap中维护了一个初始容量为16的entry数组。这个entry数组就是存储数据的底层结构,还有一个阈值,看过HashMap底层源码的就不会对这个概念陌生,另外其实还有一个负载因子,不过这个负载因子并无声明成员变量,而是在代码中直接使用的,这个负载因子为三分之二,咱们能够看下setThreshold()这个方法,threshold = len * 2 / 3。
继续往下看,有两个方法比较重要的,是我们理解ThreadLocalMap数据结构的重要切入点。
//根据当前索引位置和数组长度获取下一个索引值 private static int nextIndex(int i, int len) { return ((i + 1 < len) ? i + 1 : 0); } //根据当前索引位置和数组长度获取上一个索引值 private static int prevIndex(int i, int len) { return ((i - 1 >= 0) ? i - 1 : len - 1); }
咱们看nextIndex()方法,当当前索引值加1,若是小于数组长度i+1,不然返回0。就是说若是当前索引值加一等于数组的长度就返回0。咱们想到了啥?圆钟,23点再加一个小时等于24点,24就为一天的中时数,而24点也是零点,起点。咱们会想到Entry数组是一个环形状。再看nextIndex()方法,当前索引值减1后若是小于0,返回数组的长度减1,即15,就是i等于0的时候,i减一不是等于负一,而是十五,这个时候咱们能够确认entry数组就是一个环形结构。使用线性探测法来解决散列冲突的。
下图即为Entry数组的结构图
图片来源于:https://www.cnblogs.com/micrari/p/6790229.html
Entry数组上每一个节点为一个Entry,每一个Entry由一个指向ThreadLocal的的弱引用为key,value即为咱们设置的变量值。
这里再想下怎么经过Key(ThreadLocal)来计算索引值?
这个计算索引值不是经过相似key.hashCode()这种方式来计算的,而是根据类型为AtomicInteger的nextHashCode成员变量和增量值HASH_INCREMENT成员变量来计算的,计算方式就是经过nextHashCode加上HASH_INCREMENT值的和与Entry数组长度的位与运算来计算的。如代码所示。
int i = key.threadLocalHashCode & (table.length - 1); private final int threadLocalHashCode = nextHashCode(); private static int nextHashCode() { return nextHashCode.getAndAdd(HASH_INCREMENT); }
理解了Entry数组的数据结构,咱们继续看ThreadLocalMap提供的主要方法。
获取:private Entry getEntry(ThreadLocal<?> key)
//根据key值获取Entry节点 private Entry getEntry(ThreadLocal<?> key) { //根据key值计算索引位置 int i = key.threadLocalHashCode & (table.length - 1); //获取entry数组中该索引位置的Entry节点 Entry e = table[i]; if (e != null && e.get() == key) //若是e不为null而且e的Reference(ThreadLocal)与key相同,直接返回e节点 return e; else //若是根据计算出的索引值没有找到Entry节点 return getEntryAfterMiss(key, i, e); } private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) { Entry[] tab = table; int len = tab.length; while (e != null) { //若是e不为null //获取entry的key,即ThreadLocal ThreadLocal<?> k = e.get(); if (k == key) //若是和key相等直接返回该元素 return e; if (k == null) //若是k为null,清理无效的entry,或者说清理ThreadLocal已经被回收的entry expungeStaleEntry(i); else i = nextIndex(i, len); e = tab[i]; } //若是e为null,就直接返回null了 return null; } //该方法主要作了两件事 //第一将索引为staleSlot的节点entry的value置为null,而且将entry置为null,有利于垃圾回收 //第二从索引stateSlot的下一个索引处开始遍历判断每一个entry的ThreadLocal是否为null,若是为null,将 //该entry的value和entry自己置为null,若是不为null,进行rehash从新计算索引值,判断从新计算出来的 //索引值和当前循环的索引值是否相等,若是相等,进入下一个循环,若是不等,在环形索引中寻找为节点为空的 //下角标,将e节点放置在这个索引位置 private int expungeStaleEntry(int staleSlot) { //获取ThreadLocalMap的entry数组和数组的长度 Entry[] tab = table; int len = tab.length; //由于在getEntryAfterMiss方法中已经断定k==null了 //既然key为null,因此显示将key对应的value置为null tab[staleSlot].value = null; //显示将这个节点entry也置为null,置为null有助于垃圾回收 tab[staleSlot] = null; //entry数组的元素个数减一 size--; //执行Rehash直到再次遇到null值 Entry e; int i; //循环遍历,i的初始值为当前下角标stateSlot的下一个索引位置 for (i = nextIndex(staleSlot, len); //将entry数组中下角标为当前遍历的角标i的节点赋值给e (e = tab[i]) != null; //每循环完一次去获取下一个索引位置赋值给i i = nextIndex(i, len)) { //获取当前遍历的entry的key值,即ThreadLocal ThreadLocal<?> k = e.get(); if (k == null) { //若是key(threadLocal)为null,即把key对应的value和当前这个节点都置为null //有助于垃圾回收 e.value = null; tab[i] = null; size--; } else { //若是key不为null //计算索引值 int h = k.threadLocalHashCode & (len - 1); if (h != i) { //若是新计算的索引值跟如今遍历的索引值不相等 //将当前遍历的索引值对应的节点置为null tab[i] = null; // Unlike Knuth 6.4 Algorithm R, we must scan until // null because multiple entries could have been stale. //在环形索引中寻找为节点为空的下角标,将e节点放置在这个索引位置 while (tab[h] != null) h = nextIndex(h, len); tab[h] = e; } } } return i; }
ThreadLocalMap经过key(ThreadLocal)来获取Entry节点,首先经过key来计算索引值,再经过索引值获取到某个Entry。若是Entry的key与参数key相同,则直接返回这个Entry节点;若是Entry为null,则直接返回null;若是Entry不为null,可是key不相同,就走getEntryAfterMiss()这个方法。这个方法里面主要是判断entry的key(ThreadLocal)。若是key既不相等也不为null,循环遍历下个索引值对应的entry。可是若是key为null,这个时候会走expungeStaleEntry()方法了,这个方法比较重要,咱们单独来讲说。
首先咱们想象key为null表明着什么?key为threadLocal,即threadLocal为null,而threadLocal为弱引用指向的,其实这里表示为ThreadLocal被回收了,虽然ThreadLocal被回收了,可是key对应的value是跟Thread挂钩的,value可能还没被回收,因此这里咱们须要显示的将value和entry置为null,以便于垃圾回收这些对象,同时防止内存泄露。不只如此代码中还会开始遍历该entry索引后面的整个Entry数组,若是那个entry的key为null,都会显示将object和entry置为null,让其被回收,防止内存泄露。
设置值:private void set(ThreadLocal<?> key , Object value)
private void set(ThreadLocal<?> key, Object value) { //获取entry数组和数组的长度 Entry[] tab = table; int len = tab.length; //计算key值对应的索引位置 int i = key.threadLocalHashCode & (len-1); //根据计算的索引值获取对应的Entry,从该索引处开始循环向后遍历 for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { //根据Entry获取ThreadLocal ThreadLocal<?> k = e.get(); if (k == key) { //若是key与当前entry的key相同 //直接用参数value覆盖entry中的原value e.value = value; return; } if (k == null) { //若是k为null //替换无效的entry replaceStaleEntry(key, value, i); return; } } //建立一个新的Entry节点 tab[i] = new Entry(key, value); int sz = ++size; if (!cleanSomeSlots(i, sz) && sz >= threshold) //若是元素个数大于或者等于阈值,扩容 rehash(); } private void replaceStaleEntry(ThreadLocal<?> key, Object value,int staleSlot) { Entry[] tab = table; int len = tab.length; Entry e; int slotToExpunge = staleSlot; //向索引staleSlot的前面开始循环遍历,直到tab[i]不为null //向前遍历找到最近的一个ThreadLocal为null的entry for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len)) if (e.get() == null) //若是entry的key(ThreadLocal)为null //获取entry的索引值 slotToExpunge = i; //向staleSlot的后面遍历 for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { ThreadLocal<?> k = e.get(); if (k == key) { //若是entry的key等于参数key //直接覆盖entry的value值 e.value = value; //?? tab[i] = tab[staleSlot]; tab[staleSlot] = e; if (slotToExpunge == staleSlot) slotToExpunge = i; cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); return; } if (k == null && slotToExpunge == staleSlot) slotToExpunge = i; } // If key not found, put new entry in stale slot tab[staleSlot].value = null; tab[staleSlot] = new Entry(key, value); // If there are any other stale entries in run, expunge them if (slotToExpunge != staleSlot) cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); } private boolean cleanSomeSlots(int i, int n) { boolean removed = false; Entry[] tab = table; int len = tab.length; do { i = nextIndex(i, len); Entry e = tab[i]; if (e != null && e.get() == null) { n = len; removed = true; i = expungeStaleEntry(i); } } while ( (n >>>= 1) != 0); return removed; } private void rehash() { expungeStaleEntries(); // Use lower threshold for doubling to avoid hysteresis if (size >= threshold - threshold / 4) resize(); } private void resize() { Entry[] oldTab = table; int oldLen = oldTab.length; int newLen = oldLen * 2; Entry[] newTab = new Entry[newLen]; int count = 0; for (int j = 0; j < oldLen; ++j) { Entry e = oldTab[j]; if (e != null) { ThreadLocal<?> k = e.get(); if (k == null) { e.value = null; // Help the GC } else { int h = k.threadLocalHashCode & (newLen - 1); while (newTab[h] != null) h = nextIndex(h, newLen); newTab[h] = e; count++; } } } setThreshold(newLen); size = count; table = newTab; }
终于看完ThreadLocalMap了,咱们能够接着看ThreadLocal的代码了!。
protected T initialValue() { return null; }
设置,void set(T value);
public void set(T value) { //获取当前线程 Thread t = Thread.currentThread(); //获取当前线程的TreadLocalMap ThreadLocalMap map = getMap(t); if (map != null) //若是ThreadLocalMap不为null,直接调用ThreadLocalMap的set方法 map.set(this, value); else //若是ThreadLocalMap为,以当前线程和value值建立ThreadLocalMap createMap(t, value); } ThreadLocalMap getMap(Thread t) { return t.threadLocals; } void createMap(Thread t, T firstValue) { //建立ThreadLocalMap并用当前线程指向该map t.threadLocals = new ThreadLocalMap(this, firstValue); }
从set()方法能够看出每一个线程(Thread)有一个threadLocals变量,如代码所示:
//Thread类的成员变量 ThreadLocal.ThreadLocalMap threadLocals = null;
ThreadLocal在设置值的时候,会先判断当前线程有没有初始化ThreadLocalMap,若是没有,先根据当前thredLocal(key)和value值生成ThreadLocalMap,并用该线程的成员变量threadLocals指向这个ThreadLocalMap;若是当前线程已经关联ThreadLocalMap了,则直接经过ThreadLocalMap的set方法设置值。
获取:T get();
public T get() { //获取当前线程 Thread t = Thread.currentThread(); //获取当前线程关联的ThreadLocalMap ThreadLocalMap map = getMap(t); if (map != null) { //若是ThreadLocalMap不为null,根据key(ThreadLocal)值获取entry ThreadLocalMap.Entry e = map.getEntry(this); if (e != null) { @SuppressWarnings("unchecked") T result = (T)e.value; //获取entry的value值返回 return result; } } //不然初始化当前线程的ThreadLocalMap,value为null return setInitialValue(); } private T setInitialValue() { //value为空 T value = initialValue(); //获取当前线程 Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) map.set(this, value); else createMap(t, value); return value; } protected T initialValue() { return null; }
到此,ThreadLocal的主要代码就介绍完了。
ThreadLocal是否存在内存泄露问题?
会,咱们先来看下ThreadLocal的引用和数据结构图,图片来源于:http://www.importnew.com/22039.html,map指的ThreadLocalMap,实线表明强引用,虚线表明弱引用。
咱们看到ThreadLocal有一个强引用和一个弱引用,强引用来自高层代码中的引用,好比ThreadLocal tl = new TheadLocal(),tl这就是一个强引用,而弱应用来自于ThreadLocalMap中的Entry的key的引用。当高层代码中把threadlocal实例置为null之后,就没有任何强引用指向threadlocal实例,而只有一个弱引用去指向ThreadLocal,可是咱们知道弱引用指向的对象在GC时是会被回收的,因此threadlocal将会被gc回收。这也是Entry中的key使用弱应用的缘由,不然TreadLoca就算在高层代码中释放引用后,由于Entry还存在,key仍然指向ThreadLocal,因此让不会被回收,容易形成内存泄露。
当ThreadLocal被回收后,咱们的value还不能回收,由于存在一条从current thread链接过来的强引用.,只要thread存在,这个引用就会一直存在,只有当thread结束之后, current thread才会被销毁,强引用才会断开, 此时Current Thread, Map, value才能所有被GC回收。
因此这里存在一个风险就是,在current Thread到销毁的这段时间内,存在因为value值过多或者过大致使的内存泄露问题,咱们在想下,若是咱们是使用的线程池,出现什么结果,线程用完后,直接放回线程池中,不会被销毁,那么那些value就会一直存在,这样产生内存泄露的可能性大大增长。
JDK是怎么解决这个问题的呢?
咱们回过头来在看看ThreadLocalMap的set和get方法,咱们发现代码里都会循环遍历Entry数组,检查entey中的key(ThreadLocal)是否为null,若是为null,会显示的将entry的value和entry自己置为null,这样以便entry和entry的value能被GC回收,防止内存泄露。
既然知道了内存泄露的来龙去脉,咱们在使用TheadLocal时候就要特别注意这方面的问题,好比咱们再用完TheadLocal后记得用remove()方法去清除数据。