ThreadLocal源码解析

ThreadLocal的作用

ThreadLocal的作用是提供线程内的局部变量,说白了,就是在各线程内部创建一个变量的副本,相比于使用各种锁机制访问变量,ThreadLocal的思想就是用空间换时间,使各线程都能访问属于自己这一份的变量副本,变量值不互相干扰,减少同一个线程内的多个函数或者组件之间一些公共变量传递的复杂度。

ThreadLocal特性及使用场景

  • 1、方便同一个线程使用某一对象,避免不必要的参数传递;
  • 2、线程间数据隔离(每个线程在自己线程里使用自己的局部变量,各线程间的ThreadLocal对象互不影响);
  • 3、获取数据库连接、Session、关联ID(比如日志的uniqueID,方便串起多个日志);

ThreadLocal应注意

  • 1、ThreadLocal并未解决多线程访问共享对象的问题;
  • 2、ThreadLocal并不是每个线程拷贝一个对象,而是直接new(新建)一个;
  • 3、如果ThreadLocal.set()的对象是多线程共享的,那么还是涉及并发问题。

图解TreadLocal

每个线程可能有多个ThreadLocal,同一线程的各个ThreadLocal存放于同一个ThreadLocalMap中。

图解ThreadLocal(JDK8).vsdx原图下载地址:https://github.com/zxiaofan/JDK-Study/tree/master/src/java1/lang/threadLocal

内部类

ThreadLocalMap

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
static class ThreadLocalMap {


static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

/**
* ThreadLocalMap的key是ThreadLocal
* value是Object(即我们所谓的“线程本地数据”)
*/
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

/**
* 初始容量,2的幂等次方
*/
private static final int INITIAL_CAPACITY = 16;

/**
* 实际保存数据的数组,超过threshold会2倍扩容
*/
private Entry[] table;

/**
* 实际存储的entry数量
*/
private int size = 0;

/**
* 下次扩容的阈值
*/
private int threshold; // Default to 0

/**
* Set the resize threshold to maintain at worst a 2/3 load factor.
*/
private void setThreshold(int len) {
threshold = len * 2 / 3;
}

/**
* 往后移动一位
*/
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}

/**
* 往前移动一位
*/
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}

/**
* Construct a new map initially containing (firstKey, firstValue).
* ThreadLocalMaps懒汉模式, 等第一个entry被放入时才初始化.
*/
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}

/**
* 将父线程的ThreadLocalMaps内容复制过来
* Called only by createInheritedMap.
*/
private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];

for (int j = 0; j < len; j++) {
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}
}

ThreadLocalMap是定制的hashMap,仅用于维护当前线程的本地变量值。仅ThreadLocal类对其有操作权限,是Thread的私有属性。为避免占用空间较大或生命周期较长的数据常驻于内存引发一系列问题,hash table的key是弱引用WeakReferences。当空间不足时,会清理未被引用的entry。这时Entry里的key为null了,那么直到线程结束前,Entry中的value都是无法回收的,这里可能产生内存泄露

SuppliedThreadLocal

1
2
3
4
5
6
7
8
9
10
11
12
13
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

private final Supplier<? extends T> supplier;

SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}

@Override
protected T initialValue() {
return supplier.get();
}
}

SuppliedThreadLocal是JDK8新增的内部类,只是扩展了ThreadLocal的初始化值的方法而已,允许使用JDK8新增的Lambda表达式赋值。需要注意的是,函数式接口Supplier不允许为null。

初始化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
public class ThreadLocal<T> {

/**
* ThreadLocal初始化时会调用nextHashCode()方法初始化
* threadLocalHashCode,且threadLocalHashCode初始化后不可变。
* threadLocalHashCode可用来标记不同的ThreadLocal实例。
*/
private final int threadLocalHashCode = nextHashCode();

private static AtomicInteger nextHashCode =
new AtomicInteger();

private static final int HASH_INCREMENT = 0x61c88647;

private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}


protected T initialValue() {
return null;
}

/**
* JDK8新增,支持Lambda表达式,和ThreadLocal重写的initialValue()效果一样。
*/
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);
}

public ThreadLocal() {
}
}

ThreadLocal类变量有3个,其中2个是静态变量(包括一个常量),实际作为作为ThreadLocal实例的变量只有threadLocalHashCode这1个,而且已经初始化就不可变了。

其中withInitial()方法使用示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
public void jdk8Test(){
Supplier<String> supplier =new Supplier<String>(){
@Override
public String get(){
return"supplier_new";
}
};

threadLocal= ThreadLocal.withInitial(supplier);
System.out.println(threadLocal.get()); // supplier_new

// Lambda表达式
threadLocal= ThreadLocal.withInitial(()->"sup_new_2");
System.out.println(threadLocal.get()); // sup_new_2

ThreadLocal<DateFormat> localDate = ThreadLocal.withInitial(()->new SimpleDateFormat("yyyy-MM-dd"));
System.out.println(localDate.get().format(new Date())); // 2017-01-22

ThreadLocal<String> local =new ThreadLocal<>().withInitial(supplier);
System.out.println(local.get()); // supplier_new
}

源码分析

get方法

1
2
3
4
5
6
7
8
9
10
11
12
13
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}

直接看代码,可以分析主要有以下几步:

    1. 获取当前的Thread对象,通过getMap获取Thread内的ThreadLocalMap
    1. 如果map已经存在,以当前的ThreadLocal为键,获取Entry对象,并从从Entry中取出值
    1. 否则,调用setInitialValue进行初始化。

getMap

1
2
3
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}

getMap很简单,就是返回线程中ThreadLocalMap,跳到Thread源码里看,ThreadLocalMap是这么定义的:

1
ThreadLocal.ThreadLocalMap threadLocals = null;

所以ThreadLocalMap还是定义在ThreadLocal里面的,我们前面已经说过ThreadLocalMap中的Entry定义,下面为了先介绍ThreadLocalMap的定义我们把setInitialValue放在前面说。

setInitialValue

1
2
3
4
5
6
7
8
9
10
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}

setInititialValue在Map不存在的时候调用。

  1. 首先是调用initialValue生成一个初始的value值,深入initialValue函数,我们可知它就是返回一个null,如果创建ThreadLocal时调用withInitial() 方法指定了初始方法,则返回自定义值

  2. 还是在get()一下Map,如果map存在,则直接map.set(), 这个函数会放在后文说;

  1. 如果map不存在,则会调用createMap()创建ThreadLocalMap。

createMap

1
2
3
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}

比较简单,就是调用了ThreadLocalMap内部类的构造函数而已。

map.getEntry

1
2
3
4
5
6
7
8
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}
  1. 首先是计算索引位置i,通过计算key的hash%(table.length-1)得出;
  2. 根据获取Entry,如果Entry存在且Entry的key恰巧等于ThreadLocal,那么直接返回Entry对象;
  3. 否则,也就是在此位置上找不到对应的Entry,那么就调用getEntryAfterMiss。

getEntryAfterMiss

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;

while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}

这个方法我们还得结合上一步看,上一步是因为不满足

1
e != null && e.get() == key

才沦落到调用getEntryAfterMiss的,所以:

  • 首先e如果为null的话,证明不存在value, 那么getEntryAfterMiss还是直接返回null的

  • 如果是不满足e.get() == key,那么进入while循环,这里是不断循环,如果e一直不为空,那么就调用nextIndex,不断递增i,在此过程中一直会做两个判断:

    • 如果 k == key, 那么代表找到了这个所需要的Entry,直接返回;

    • 如果 k == null,那么证明这个Entry中key已经为null, 那么这个Entry就是一个过期对象,这里调用expungeStaleEntry清理该Entry。这里解答了前面留下的一个坑,即ThreadLocal Ref销毁时,ThreadLocal实例由于只有Entry中的一条弱引用指着,那么就会被GC掉,Entry的key没了,value可能会内存泄露的,其实在每一个get,set操作时都会不断清理掉这种key为null的Entry的

为什么循环查找?

这里你可以直接跳到下面的set方法,主要是因为处理哈希冲突的方法,我们都知道HashMap采用拉链法处理哈希冲突,即在一个位置已经有元素了,就采用链表把冲突的元素链接在该元素后面,而ThreadLocal采用的是开放地址法,即有冲突后,把要插入的元素放在要插入的位置后面为null的地方

具体关于这两种方法的区别可以参考:解决哈希(HASH)冲突的主要方法

所以上面的循环就是因为我们在第一次计算出来的i位置不一定存在key与我们想查找的key恰好相等的Entry,所以只能不断在后面循环,来查找是不是被插到后面了,直到找到为null的元素,因为若是插入也是到null为止的。

expungeStaleEntry

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

// (1)删掉staleSlot位置value值
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

// (2)Rehash until we encounter null
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
// 删除元素后,需要重新移动存活的元素,因为查找时遇到null会终止
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;

while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}

看上面这段代码主要有两部分:

  • (1) 这段主要是将i位置上的Entry的value设为null,Entry的引用也设为null,那么系统GC的时候自然会清理掉这块内存;

  • (2) 这段就是扫描位置staleSlot之后,null之前的Entry数组,清除每一个key为null的Entry,同时若是key不为空,做rehash,调整其位置。

为什么要做rehash呢?

因为我们在清理的过程中会把某个值设为null,那么这个值后面的区域如果之前是连着前面的,那么下次循环查找时,就会只查到null为止。

举个例子就是:

…, <key1(hash1), value1>, <key2(hash1), value2>,…

即key1和key2的hash值相同, 此时,若插入

<key3(hash2), value3>

其hash计算的目标位置被

<key2(hash1), value2>

占了,于是往后寻找可用位置,hash表可能变为:

…, <key1(hash1), value1>, <key2(hash1), value2>, <key3(hash2), value3>, …

此时,若

<key2(hash1), value2>

被清理,显然

<key3(hash2), value3>
应该往前移(即通过rehash调整位置),否则若以key3查找hash表,将会找不到key3。

set方法

我们在get方法的循环查找那里也大概描述了set方法的思想,即开放地址法,下面看具体代码:

1
2
3
4
5
6
7
8
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}

首先也是获取当前线程,根据线程获取到ThreadLocalMap,若是有ThreadLocalMap,则调用

1
map.set(ThreadLocal<?> key, Object value)

若是没有则调用createMap创建。

map.set

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();

if (k == key) {
e.value = value;
return;
}

if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}

tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

看上面这段代码:

  1. 首先还是根据key计算出位置i,然后查找i位置上的Entry,

  2. 若是Entry已经存在并且key等于传入的key,那么这时候直接给这个Entry赋新的value值。

  3. 若是Entry (e != null) 存在,但是key为null,则调用replaceStaleEntry来更换这个key为空的Entry

  4. 不断循环检测,直到遇到为null的地方,这时候要是还没在循环过程中return,那么就在这个null的位置新建一个Entry,并且插入,同时size增加1。

  5. 最后调用cleanSomeSlots,这个函数就不细说了,你只要知道内部还是调用了上面提到的expungeStaleEntry函数清理key为null的Entry就行了,最后返回是否清理了Entry,接下来再判断 sz>thresgold ,这里就是判断是否达到了rehash的条件,达到的话就会调用rehash函数。

上面这段代码有两个函数还需要分析下,首先是:

replaceStaleEntry

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;

// 向前找到key为null的位置
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;

// staleSlot节点key为空,属于应该清理节点
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value; // 更新value值

tab[i] = tab[staleSlot]; // i指向key为空节点
tab[staleSlot] = e;

// staleSlot前面全不为空,i节点指向最新key为null位置
if (slotToExpunge == staleSlot)
slotToExpunge = i;
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}

// 更新key为空节点位置
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}

// If key not found, put new entry in stale slot
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);

// If there are any other stale entries in run, expunge them
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

首先我们回想上一步是因为这个位置的Entry的key为null才调用replaceStaleEntry。

  1. 第1个for循环:我们向前找到key为null的位置,记录为slotToExpunge,这里是为了后面的清理过程,可以不关注了;

  2. 第2个for循环:我们从staleSlot起到下一个null为止,若是找到key和传入key相等的Entry,就给这个Entry赋新的value值,并且把它和staleSlot位置的Entry交换,然后调用CleanSomeSlots清理key为null的Entry。

  3. 若是一直没有key和传入key相等的Entry,那么就在staleSlot处新建一个Entry。函数最后再清理一遍空key的Entry。

说完replaceStaleEntry,还有个重要的函数是rehash以及rehash的条件:

首先是sz > threshold时调用rehash

rehash

1
2
3
4
5
6
7
8
private void rehash() {
// 清理全部空节点
expungeStaleEntries();

// Use lower threshold for doubling to avoid hysteresis
if (size >= threshold - threshold / 4)
resize();
}

清理完空key的Entry后,如果size大于3/4的threshold,则调用resize函数:

resize

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;

for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC 下次gc会被回收
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}

setThreshold(newLen);
size = count;
table = newTab;
}

由源码我们可知每次扩容大小扩展为原来的2倍,然后再一个for循环里,清除空key的Entry,同时重新计算key不为空的Entry的hash值,把它们放到正确的位置上,再更新ThreadLocalMap的所有属性。

remove

最后一个需要探究的就是remove函数,它用于在map中移除一个不用的Entry。也是先计算出hash值,若是第一次没有命中,就循环直到null,在此过程中也会调用expungeStaleEntry清除空key节点。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}

使用ThreadLocal的最佳实践

我们发现无论是set,get还是remove方法,过程中key为null的Entry都会被擦除,那么Entry内的value也就没有强引用链,GC时就会被回收。那么怎么会存在内存泄露呢?但是以上的思路是假设你调用get或者set方法了,很多时候我们都没有调用过,所以最佳实践就是:

  1. 使用者需要手动调用remove函数,删除不再使用的ThreadLocal.
  2. 尽量将ThreadLocal设置成private static的,这样ThreadLocal会尽量和线程本身一起消亡。

问题与思考

(1)如果有多个ThreadLocal都对同一个线程ThreadLocalMap写数据时,可能存在hash位置冲突,导致set()和get()效率显著下降;

(2)ThreadLocal不能读取父线程的ThradLocalMap内容,需要使用InheritableThreadLocal;