jdk源码(二):你知道ConcurrentHashMap的具体实现细节吗?

一、首先抛出几个问题(文章最后有答案):java

a、ConcurrentHashMap在put的时候,key通过几回hash计算?node

b、segment 会增大吗?数组

c、新的值是放在链表的表头仍是表尾?安全

二、ConcurrentHashMap是如何存储数据的?数据结构

先看图:并发

从图中咱们能够看出ConcurrentHashMap有两个种数据结构:数组和单向链表ssh

那ConcurrentHashMap和如何存放一对key和value呢?高并发

put的具体过程:性能

a、根据key计算hash值优化

b、根据hash值找到segment数组的下标

c、根据上面的下标获取tab数组,

d、根据hash值,获取tab数组的下标

c、若是tab当前下标位置上没有值,就直接把存储有key和value的HashEntry存放在tab的当前下标下,不然就是造成一个链表(解决了Hash值冲突)

这就是整个put的大概过程。

是否是有小伙伴说,裤子都脱了,你给我看这个?哈哈哈哈哈,好,上代码

public V put(K key, V value) {
        Segment<K,V> s;
        if (value == null)
            throw new NullPointerException();
        int hash = hash(key); // 根据key获取hash值
        int j = (hash >>> segmentShift) & segmentMask; // 定位segment数组的下标j
        if ((s = (Segment<K,V>)UNSAFE.getObject         
             (segments, (j << SSHIFT) + SBASE)) == null) //  根据下标j获取数组该下标的元素s
            s = ensureSegment(j);
        return s.put(key, hash, value, false);
    }

看代码,你们是否是也有些不太明白呢?那我就一行一行的解释吧

解释以前咱们得先了解几个参数:(注:涉及到unsafe的相关用法参考:https://my.oschina.net/huangy/blog/1620321)

segmentShift、segmentMask、SSHIFT,SBASE,segments
public ConcurrentHashMap(int initialCapacity,
                             float loadFactor, int concurrencyLevel) {
        // 默认值initialCapacity=16 loadFactor=0.75 concurrencyLevel=16
        // initialCapacity 决定tab数组的初始化长度,concurrencyLevel决定segment数组的长度
        if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
        // Find power-of-two sizes best matching arguments
        // 找到一个ssize不小于concurrencyLevel且必须是2的n次幂,为何呢?下面解释
        int sshift = 0;
        int ssize = 1;        
        while (ssize < concurrencyLevel) {
            ++sshift;
            ssize <<= 1;
        }
        this.segmentShift = 32 - sshift; // sshift=4,segmentShift=28
        this.segmentMask = ssize - 1; // segment数组的长度=ssize=16,segmentMask=15
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        int c = initialCapacity / ssize;
        if (c * ssize < initialCapacity)
            ++c;
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        while (cap < c)
            cap <<= 1;
        // create segments and segments[0]
        Segment<K,V> s0 =
            new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                             (HashEntry<K,V>[])new HashEntry[cap]);
        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
        UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
        this.segments = ss; // segments长度=16
    }
static {
        int ss, ts;
        try {
            UNSAFE = sun.misc.Unsafe.getUnsafe();
            Class tc = HashEntry[].class;
            Class sc = Segment[].class;
            TBASE = UNSAFE.arrayBaseOffset(tc);// 获取HashEntry[]的基本偏移量=6
            SBASE = UNSAFE.arrayBaseOffset(sc);// 获取Segment[]的基本偏移量=6
            ts = UNSAFE.arrayIndexScale(tc);// 获取HashEntry[]单位偏移量=4
            ss = UNSAFE.arrayIndexScale(sc);//获取Segment[]单位偏移量=4
            HASHSEED_OFFSET = UNSAFE.objectFieldOffset(
                ConcurrentHashMap.class.getDeclaredField("hashSeed"));
            SEGSHIFT_OFFSET = UNSAFE.objectFieldOffset(
                ConcurrentHashMap.class.getDeclaredField("segmentShift"));
            SEGMASK_OFFSET = UNSAFE.objectFieldOffset(
                ConcurrentHashMap.class.getDeclaredField("segmentMask"));
            SEGMENTS_OFFSET = UNSAFE.objectFieldOffset(
                ConcurrentHashMap.class.getDeclaredField("segments"));
        } catch (Exception e) {
            throw new Error(e);
        }
        if ((ss & (ss-1)) != 0 || (ts & (ts-1)) != 0)
            throw new Error("data type scale not a power of two");
        SSHIFT = 31 - Integer.numberOfLeadingZeros(ss);//把Segment[]单位偏移量转成移位的次数=2
        TSHIFT = 31 - Integer.numberOfLeadingZeros(ts);//把HashEntry[]单位偏移量转成移位的次数=2
    }

开始解释代码: 

public V put(K key, V value) {
        Segment<K,V> s;
        if (value == null)
            throw new NullPointerException();
        int hash = hash(key); // 根据key获取hash值

        int j = (hash >>> segmentShift) & segmentMask; // 定位segment数组的下标j
        // 上面知道segmentShift=28,segmentMask=15(二进制:00000000000000000000000000001111)
        //  假设hash=994162679 二进制:00111011010000011011011111110111
        //  (hash >>> segmentShift)
        //  这句话的意思就是hash右移28 结果=3 (二级制:00000000000000000000000000000011)
        //  int j = (hash >>> segmentShift) & segmentMask;
        //  就变成
        //  int j = 3 & segmentMask;
        //  00000000000000000000000000000011 & 00000000000000000000000000001111
        //  的结果=00000000000000000000000000000011
        //  因此 j = 3
        //  这就根据hash值定位segments的数组下标 j=3
        //  咱们回想一下:segments数组长度ssize=16
        //  segmentMask = ssize-1 = 15
        //  而后每一个通过右移28位的hash值和segmentMask进行与操做,
        //  就能保证j必定落在数组内,保证了不越界,同时效率也很是高
        //  这就是要找到一个ssize不小于concurrencyLevel且必须是2的n次幂的缘由
        //  HashMap 也是这么干的,能够关注我查看个人相关文章。

        if ((s = (Segment<K,V>)UNSAFE.getObject         
             (segments, (j << SSHIFT) + SBASE)) == null) //  根据下标j获取数组该下标的元素s  
        //  UNSAFE.getObject(segments, (j << SSHIFT) + SBASE))
        //  根据上面知道:SSHIFT=2 SBASE=6
        //  同时刚刚求出 j=3
        //  (j << SSHIFT) + SBASE 的结果是12+6=20
        //  UNSAFE.getObject()会根据segments和偏移量获得数组下标=3的元素
        //  UNSAFE.getObject(segments, (j << SSHIFT) + SBASE)) 能够简单认为是 segments[3]
        //  只不过这是直接内存操做很是高效而已,这样的操做ConcurrentHashMap用的很是多
        //  不明白能够查看个人相关文章 
            s = ensureSegment(j);
        return s.put(key, hash, value, false);
    }
private Segment<K,V> ensureSegment(int k) {
        // 这里一系列的unsafe操做请查看个人相关文章
        final Segment<K,V>[] ss = this.segments;
        long u = (k << SSHIFT) + SBASE; //仍是获取偏移量
        Segment<K,V> seg;
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
            Segment<K,V> proto = ss[0]; // use segment 0 as prototype
            int cap = proto.table.length; // cap=16
            float lf = proto.loadFactor; // lf =0.75
            int threshold = (int)(cap * lf);
            HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
            if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                == null) { // recheck
                Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
                while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                       == null) {
                    // UNSAFE.compareAndSwapObject 保证了原子性,能够思考一下没有原子保证,会有什么后果
                    if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                        break;
                }
            }
        }
        return seg;
    }

分析:s.put(key, hash, value, false);

分析以前咱们了解Segment,它继承了ReentrantLock,咱们知道ConcurrentHashMap 是线程安全的,这就是关键点了

s.put()执行过程

a、尝试获取锁

b、根据hash值获取tab数组下标

c、tab数组当前下标,是否有HashEntry,有则遍历,没有则建立一个HashEntry

d、释放锁

大概就是这么一个过程

看代码

final V put(K key, int hash, V value, boolean onlyIfAbsent) {
            // 尝试获取锁
            HashEntry<K,V> node = tryLock() ? null :
                scanAndLockForPut(key, hash, value);
            V oldValue;
            try {
                HashEntry<K,V>[] tab = table;
                int index = (tab.length - 1) & hash;
               // tab.length=16 16-1=15(二进制00000000000000000000000000001111)
               // 看到这个是否是很熟悉,不解释了
                HashEntry<K,V> first = entryAt(tab, index);// 根据数组和下标获取元素,仍是unsafe操做
                //遍历一次就搞定了全部事情,若是是你写,你会怎么写?
                for (HashEntry<K,V> e = first;;) {
                    if (e != null) {
                        // e,若是不为空,则表示tab数组当前这个下标,已经有值,极可能造成一个链表,
                        // 
                        K k;
                        if ((k = e.key) == key ||
                            (e.hash == hash && key.equals(k))) {
                            // 若是key和hash都相同,表示同一个,按要求看是否要更新value
                            oldValue = e.value;
                            if (!onlyIfAbsent) {
                                e.value = value;
                                ++modCount;
                            }
                            break;
                        }
                        e = e.next;
                    }
                    else {
                        if (node != null)
                            node.setNext(first);// 从这里能够看出新进来的,会作链表头
                        else
                            node = new HashEntry<K,V>(hash, key, value, first);
                        int c = count + 1;
                        if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                            rehash(node);
                        else
                            setEntryAt(tab, index, node);
                        ++modCount;
                        count = c;
                        oldValue = null;
                        break;
                    }
                }
            } finally {
                unlock();
            }
            return oldValue;
        }

其实最让人膜拜的代码是:

HashEntry<K,V> node = tryLock() ? null :scanAndLockForPut(key, hash, value);

tryLock()尝试获取锁,获取不到就执行scanAndLockForPut,咱们看看scanAndLockForPut都干吗了

private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
            // entryForHash根据hash值获取tab中对应的元素,看不懂能够参考以前的分析
            HashEntry<K,V> first = entryForHash(this, hash);
            HashEntry<K,V> e = first;
            HashEntry<K,V> node = null;
            int retries = -1; // negative while locating node
            // 不断的在尝试获取锁
            while (!tryLock()) {
                HashEntry<K,V> f; // to recheck first below
                if (retries < 0) {
                    if (e == null) {
                        if (node == null) // speculatively create node
                            node = new HashEntry<K,V>(hash, key, value, null);
                        retries = 0;
                    }
                    else if (key.equals(e.key))
                        retries = 0;
                    else
                        e = e.next;
                }
                else if (++retries > MAX_SCAN_RETRIES) {
                // 为了防止无限次尝试,作了个限制,通常MAX_SCAN_RETRIES=64
                    lock();
                    break;
                }
                else if ((retries & 1) == 0 &&
                         (f = entryForHash(this, hash)) != first) {
                    e = first = f; // re-traverse if entry changed
                    retries = -1;
                }
            }
            return node;
        }
        // 这段代码主要作两件事:
        // 一、获取锁,执行完这个方法确定能获得锁
        // 二、在获取锁等待的过程当中,有必要的建立新HashEntry
        // 这段代码主要优化的是:
        // 利用在获取锁等待的时间,若是发现tab当前这个下标的值为空
        // 那么建立HashEntry,而后继续获取锁,知道超过MAX_SCAN_RETRIES的次数
        // 执行到lock(),而后整个线程就会进入等待
        // 若是不是使用tryLock(),而是上来就是lock(),那么整个线程就会进入等待,什么都干不了
        // 这是一个很是小的优化,可是绝大部分应用场景都是新建立HashEntry这样的状况的,
        // 因此这个优化仍是很是值得确定的
        // 大神写的代码,是否是有种眼界提升的感受,哈哈哈哈哈哈

到此,ConcurrentHashMap的最关键的代码就是这些了,只要你能看懂这些,其余的都不在话下。

Unsafe的相关操做参考:

https://my.oschina.net/huangy/blog/1620321

三、总结

开头的问题有答案了吗?

好,揭晓答案

a、ConcurrentHashMap在put的时候,key通过几回hash计算?一次

b、segment 会增大吗?不会

c、新的值是放在链表的表头仍是表尾?表头

public ConcurrentHashMap(int initialCapacity,float loadFactor, int concurrencyLevel)

initialCapacity:控制tab数组的大小(默认16)

loadFactor:tab进行rehash阈值百分数(默认0.75)

concurrencyLevel:控制segment的大小 (默认16)

因此,一旦concurrencyLevel指定了就不能改变了

那么ConcurrentHashMap里为何分segment呢?

这就是ConcurrentHashMap高明之处,经过以前的分析咱们都知道锁只在segment中存在,这样就把锁的粒度变小,提升并发,同时仍是线程安全的,

因此,若是咱们使用ConcurrentHashMap存放数据的时候,数据很是大的时候,concurrencyLevel的指定就尤其重要了,合适concurrencyLevel的可让ConcurrentHashMap性能最佳。

最后留个问题:

hash值相同的两个对象是同一个吗?欢迎你们评论里留言

欢迎关注,转发

持续更新有意思的代码

相关文章
相关标签/搜索