本篇文章我来讨论一下什么是ThreadLocal以及它的实现原理。其底层数据结构有点相似HashMap,因此对HashMap不熟悉的朋友能够先去看一看我前面介绍HashMap的那篇文章。java
本文如如有不对或不实之处,也欢迎各位读者朋友评论指正,欢迎探讨交流。
数组
ThreadLocal提供了线程的局部变量,每一个线程均可以经过set()和get()来对这个局部变量进行操做,但不会和其余线程的局部变量进行冲突,实现了线程的数据隔离。安全
static class ThreadLocalMap {
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
//...
}复制代码
咱们的值都是存储到这个Map上的,key是当前ThreadLocal对象!数据结构
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}复制代码
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
复制代码
/* ThreadLocal values pertaining to this thread. This map is maintained * by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null复制代码
在ThreadLoalMap中,也是初始化一个大小16的Entry数组,Entry对象用来保存每个key-value键值对,只不过这里的key永远都是ThreadLocal对象,经过ThreadLocal对象的set方法,结果把ThreadLocal对象本身当作key,放进了ThreadLoalMap中。并发
public void set(T value) {
// 获得当前线程对象
Thread t = Thread.currentThread();
// 这里获取ThreadLocalMap
ThreadLocalMap map = getMap(t);
// 若是map存在,则将当前线程对象t做为key,要存储的对象
//做为value存到map里面去
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
private void set(ThreadLocal key, Object value) {
Entry[] tab = table;
int len = tab.length;
//根据hash值计算存放下标i
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal k = e.get();
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
复制代码
这样的话,在get的时候,也会根据ThreadLocal对象的hash值,定位到table中的位置,而后判断该位置Entry对象中的key是否和get的key一致,若是不一致,就判断下一个位置,能够发现,set和get若是冲突严重的话,效率很低。工具
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}
//查找对应key的 Entry
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
} 复制代码