面試不再慌,看完這篇保證讓你寫HashMap跟玩一樣

今天這篇文章給大家講講hashmap,這個號稱是所有Java工程師都會的數據結構。爲什麼說是所有Java工程師都會呢,因爲很簡單,他們不會這個找不到工作。幾乎所有面試都會問,基本上已經成了標配了。

在今天的這篇文章當中我們會揭開很多謎團。比如,爲什麼hashmap的get和put操作的複雜度是,甚至比紅黑樹還要快?hashmap和hash算法究竟是什麼關係?hashmap有哪些參數,這些參數分別是做什麼用的?hashmap是線程安全的嗎?我們怎麼來維護hashmap的平衡呢?

讓我們帶着疑問來看看hashmap的基本結構。

基本結構

hashmap這個數據結構其實並不難,它的結構非常非常清楚,我用一句話就可以說明,其實就是鄰接表。雖然這兩者的用途迥然不同,但是它們的結構是完全一樣的。說白了就是一個定長的數組,這個數組的每一個元素都是一個鏈表的頭結點。我們把這個結構畫出來,大家一看就明白了。

headers是一個定長的數組,數組當中的每一個元素都是一個鏈表的頭結點。也就是說根據這個頭結點,我們可以遍歷這個鏈表。數組是定長的,但是鏈表是變長的,所以如果我們發生元素的增刪改查,本質上都是通過鏈表來實現的。

這個就是hashmap的基本結構,如果在面試當中問到,你可以直接回答:它本質上就是一個元素是鏈表的數組。

hash的作用

現在我們搞明白了hashmap的基本結構,現在進入下一個問題,這麼一個結構和hash之間有什麼關係呢?

其實也不難猜,我們來思考一個場景。假設我們已經擁有了一個hashmap,現在新來了一份數據需要存儲。上圖當中數組的長度是6,也就是說有6個鏈表可供選擇,那麼我們應該把這個新來的元素放在哪個鏈表當中呢?

你可能會說當然是放在最短的那個,這樣鏈表的長度才能平衡。這樣的確不錯,但是有一個問題,這樣雖然存儲方便了,但是讀取的時候卻有很大的問題。因爲我們存儲的時候知道是存在最短的鏈表裏了,但是當我們讀取的時候,我們是不知道當初哪個鏈表最短了,很有可能整個結構已經面目全非了。所以我們不能根據這種動態的量來決定節點的放置位置,必須要根據靜態的量來決定。

這個靜態的量就是hash值,我們都知道hash算法的本質上是進行一個映射運算,將一個任意結構的值映射到一個整數上。我們的理想情況是不同的值映射的結果不同,相同的值映射的結果相同。也就是說一個變量和一個整數是一一對應的。但是由於我們的整數數量是有限的,而變量的取值是無窮的,那麼一定會有一些變量雖然它們並不相等但是它們映射之後的結果是一樣的。這種情況叫做hash碰撞

在hashmap當中我們並不需要理會hash碰撞,因爲我們並不追求不同的key能夠映射到不同的值。因爲我們只是要用這個hash值來決定這個節點應該存放在哪一條鏈表當中。只要hash函數確定了,只要值不變,計算得到的hash值也不會變。所以我們查詢的時候也可以遵循這個邏輯,找到key對應的hash值以及對應的鏈表。

在Python當中由於系統提供了hash函數,所以整個過程變得更加方便。我們只需要兩行代碼就可以找到key對應的鏈表。

hash_key = hash(key) % len(self.headers)
linked_list = self.headers[hash_key]

get、put實現

明白了hash函數的作用了之後,hashmap的問題就算是解決了大半。因爲剩下的就是一個在鏈表當中增刪改查的問題了,比如我們要通過key查找value的時候。當我們通過hash函數確定了是哪一個鏈表之後,剩下的就是遍歷這個鏈表找到這個值。

這個函數我們可以實現在LinkedList這個類當中,非常簡單,就是一個簡單的遍歷:

def get_by_key(self, key):
    cur = self.head.succ
    while cur != self.tail:
        if cur.key == key:
            return cur
        cur = cur.succ
    return None

鏈表的節點查詢邏輯有了之後,hashmap的查詢邏輯也就有了。因爲本質上只做了兩件事,一件事根據hash函數的值找到對應的鏈表,第二件事就是遍歷這個鏈表,找到這個節點。

我們也很容易實現:

def get(self, key):
    hash_key = self.get_hash_key(key)
    linked_list = self.headers[hash_key]
    node = linked_list.get_by_key(key)
    return node

get方法實現了之後,寫出put方法也一樣水到渠成,因爲put方法邏輯和get相反。我們把查找換成添加或者是修改即可:

def put(self, key, val):
    node = self.get(key)
    # 如果能找到,那麼只需要更新即可
    if node is not None:
        node.val = val
    else:
        # 否則我們在鏈表當中添加一個節點
        node = Node(key, val)
        linked_list.append(node)

複雜度的保障

get和put都實現了,整個hashmap是不是就實現完了?很顯然沒有,因爲還有一件很重要的事情我們沒有做,就是保證hashmap的複雜度

我們簡單分析一下就會發現,這樣實現的hashmap有一個重大的問題。就是由於hashmap一開始的鏈表的數組是定長的,不管這個數組多長,只要我們存儲的元素足夠多,那麼每一個鏈表當中分配到的元素也就會非常多。我們都知道鏈表的遍歷速度是,這樣我們還怎麼保證查詢的速度是常數級呢?

除此之外還有另外一個問題,就是hash值傾斜的問題。比如明明我們的鏈表有100個,但是我們的數據剛好hash值大部分對100取模之後都是0。於是大量的數據就會被存儲在0這個桶當中,導致其他桶沒什麼數據,就這一個桶爆滿。對於這種情況我們又怎麼避免呢?

其實不論是數據過多也好,還是分佈不均勻也罷,其實說的都是同一種情況。就是至少一個桶當中存儲的數據過多,導致效率降低。針對這種情況,hashmap當中設計了一種檢查機制,一旦某一個桶當中的元素超過某個閾值,那麼就會觸發reset。也就是把hashmap當中的鏈表數量增加一倍,並且把數據全部打亂重建。這個閾值是通過一個叫做load_factor的參數設置的,當某一個桶當中的元素大於load_factor * capacity的時候,就會觸發reset機制。

我們把reset的邏輯加進去,那麼put函數就變成了這樣:

def put(self, key, val):
    hash_key = self.get_hash_key(key)
    linked_list = self.headers[hash_key]
    # 如果超過閾值
    if linked_list.size >= self.load_factor * self.capacity:
        # 進行所有數據reset
        self.reset()
        # 對當前要加入的元素重新hash分桶
        hash_key = self.get_hash_key(key)
        linked_list = self.headers[hash_key]
        node = linked_list.get_by_key(key)
        if node is not None:
            node.val = val
        else:
            node = Node(key, val)
            linked_list.append(node)

reset的邏輯也很簡單,我們把數組的長度擴大一倍,然後把原本的數據一一讀取出來,重新hash分配到新的桶當中即可。

def reset(self):
    # 數組擴大一倍
    headers = [LinkedList() for _ in range(self.capacity * 2)]
    cap = self.capacity
    # capacity也擴大一倍
    self.capacity = self.capacity * 2
    for i in range(cap):
        linked_list = self.headers[i]
        nodes = linked_list.get_list()
        # 將原本的node一個一個填入新的鏈表當中
        for u in nodes:
            hash_key = self.get_hash_key(u.key)
            head = headers[hash_key]
            head.append(u)
    self.headers = headers

其實這裏的閾值就是我們的最大查詢時間,我們可以把它近似看成是一個比較大的常量,那麼put和get的效率就有保障了。因爲插入了大量數據或者是剛好遇到了hash不平均的情況我們就算是都解決了。

細節和昇華

如果你讀過JDK當中hashmap的源碼,你會發現hashmap的capacity也就是鏈表的數量是2的冪。這是爲什麼呢?

其實也很簡單,因爲按照我們剛纔的邏輯,當我們通過hash函數計算出了hash值之後,還需要將這個值對capacity進行取模。也就是hash(key) % capacity,這一點在剛纔的代碼當中也有體現。

這裏有一個小問題就是取模運算非常非常慢,在系統層面級比加減乘慢了數十倍。爲了優化和提升這個部分的性能,所以我們使用2的冪,這樣我們就可以用hash(key) & (capacity - 1)來代替hash(key) % capacity,因爲當capacity是2的冪時,這兩者計算是等價的。我們都知道位運算的計算速度是計算機當中所有運算最快的,這樣我們可以提升不少的計算效率。

最後聊一聊線程安全,hashmap是線程安全的嗎?答案很簡單,當然不是。因爲裏面沒有任何加鎖或者是互斥的限制,A線程在修改一個節點,B線程也可以同時在讀取同樣的節點。那麼很容易出現問題,尤其是還有reset這種時間比較長的操作。如果剛好在reset期間來了其他的查詢,那麼結果一定是查詢不到,但很有可能這個數據是存在的。所以hashmap不是線程安全的,不可以在併發場景當中使用。

最後,我們附上hashmap完整的實現代碼:

import random

class Node:
    def __init__(self, key, val, prev=None, succ=None):
        self.key = key
        self.val = val
        # 前驅
        self.prev = prev
        # 後繼
        self.succ = succ

    def __repr__(self):
        return str(self.val)


class LinkedList:
    def __init__(self):
        self.head = Node(None, 'header')
        self.tail = Node(None, 'tail')
        self.head.succ = self.tail
        self.tail.prev = self.head
        self.size = 0

    def append(self, node):
        # 將node節點添加在鏈表尾部
        prev = self.tail.prev
        node.prev = prev
        node.succ = prev.succ
        prev.succ = node
        node.succ.prev = node
        self.size += 1

    def delete(self, node):
        # 刪除節點
        prev = node.prev
        succ = node.succ
        succ.prev, prev.succ = prev, succ
        self.size -= 1

    def get_list(self):
        # 返回一個包含所有節點的list,方便上游遍歷
        ret = []
        cur = self.head.succ
        while cur != self.tail:
            ret.append(cur)
            cur = cur.succ
        return ret

    def get_by_key(self, key):
        cur = self.head.succ
        while cur != self.tail:
            if cur.key == key:
                return cur
            cur = cur.succ
        return None



class HashMap:
    def __init__(self, capacity=16, load_factor=5):
        self.capacity = capacity
        self.load_factor = load_factor
        self.headers = [LinkedList() for _ in range(capacity)]

    def get_hash_key(self, key):
        return hash(key) & (self.capacity - 1)

    def put(self, key, val):
        hash_key = self.get_hash_key(key)
        linked_list = self.headers[hash_key]
        if linked_list.size >= self.load_factor * self.capacity:
            self.reset()
            hash_key = self.get_hash_key(key)
            linked_list = self.headers[hash_key]
        node = linked_list.get_by_key(key)
        if node is not None:
            node.val = val
        else:
            node = Node(key, val)
            linked_list.append(node)

    def get(self, key):
        hash_key = self.get_hash_key(key)
        linked_list = self.headers[hash_key]
        node = linked_list.get_by_key(key)
        return node.val if node is not None else None

    def delete(self, key):
        node = self.get(key)
        if node is None:
            return False
        hash_key = self.get_hash_key(key)
        linked_list = self.headers[hash_key]
        linked_list.delete(node)
        return True

    def reset(self):
        headers = [LinkedList() for _ in range(self.capacity * 2)]
        cap = self.capacity
        self.capacity = self.capacity * 2
        for i in range(cap):
            linked_list = self.headers[i]
            nodes = linked_list.get_list()
            for u in nodes:
                hash_key = self.get_hash_key(u.key)
                head = headers[hash_key]
                head.append(u)
        self.headers = headers

今天的文章就到這裏,衷心祝願大家每天都有所收穫。如果還喜歡今天的內容的話,請來一個三連支持吧~(點贊、關注、轉發

原文鏈接,求個關注

本文使用 mdnice 排版

- END -

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章