Data Structures
First published: 2025-04-09

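These notes collect some common and genuinely useful special-purpose data structures, each with an explanation of how it works and a Python implementation that favors readability.

Note: binary trees and graphs are covered in separate notes and are not repeated here. The heap, which commonly backs priority queues, does appear, with the focus on its array representation and interface.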

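Table of Contents#

  • HashSet (separate chaining)
  • HashMap (separate chaining)
  • Stack / Queue / Deque (ring buffer)
  • Heap (priority queue, array-based)
  • Disjoint Set Union (Union-Find)
  • Trie (prefix tree)
  • Bloom Filter
  • LRU Cache (hash table + doubly linked list)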

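1. HashSet: A Set Backed by a Hash Table (Separate Chaining)#

1.1 How It Works and Complexity#

The core idea of a hash table is to map each element to a bucket index:

$$\text{index} = h(x) \bmod M$$

where $M$ is the size of the bucket array. Ideally each bucket holds only a few elements, so lookup, insertion, and deletion take $O(1)$ time on average.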

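In practice, hash collisions occur: different elements map to the same bucket. Two common ways to resolve them:

  • Separate chaining: each bucket holds a linked list or dynamic array of its elements.
  • Open addressing: on a collision, keep probing within the bucket array itself (linear probing, quadratic probing, double hashing).

This article uses separate chaining: it is simple to implement and makes deletion easy.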
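
To make the collision case concrete, here is a minimal sketch. It assumes a toy hash function `toy_hash` (a hypothetical name, simply the identity on integers) and a bucket array of size M = 8; neither is taken from the implementation that follows.

```python
def toy_hash(x):
    # Hypothetical hash for illustration: the identity on integers
    return x


def bucket_index(x, num_buckets):
    # index = h(x) mod M
    return toy_hash(x) % num_buckets


def build_chains(keys, num_buckets):
    # Separate chaining: each bucket holds a list of the keys that landed there
    buckets = [[] for _ in range(num_buckets)]
    for key in keys:
        buckets[bucket_index(key, num_buckets)].append(key)
    return buckets


chains = build_chains([3, 11, 19, 4], num_buckets=8)
print(chains[3])  # keys 3, 11 and 19 collide in bucket 3
print(chains[4])  # key 4 sits alone in bucket 4
```

A lookup for 19 only has to scan the three-element chain in bucket 3, which is why short chains keep operations close to constant time.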

1.2 Structure Diagram#

+---------------------------+
| Hash table (bucket array) |
+---------------------------+
| Index 0 | → [A] → null
+---------+
| Index 1 | → null
+---------+
| Index 2 | → [B] → [C] → [D] → null
+---------+
| Index 3 | → null
+---------+

1.3 Implementation (HashSet)#

class _HashSetNode:
    def __init__(self, value, next_node=None):
        self.value = value
        self.next = next_node


class HashSet:
    """Set implemented with separate chaining: O(1) on average, O(n) in the worst case."""

    def __init__(self, items=None, initial_capacity=8, load_factor=0.75):
        self._capacity = max(2, int(initial_capacity))
        self._buckets = [None] * self._capacity
        self._size = 0
        self._load_factor = float(load_factor)

        if items is not None:
            for item in items:
                self.add(item)

    def _index(self, item):
        # Python's hash() can be negative; the mask keeps the low 31 bits so the value is non-negative
        return (hash(item) & 0x7FFFFFFF) % self._capacity

    def _resize(self, new_capacity):
        old_buckets = self._buckets
        self._capacity = int(new_capacity)
        self._buckets = [None] * self._capacity
        self._size = 0

        for head in old_buckets:
            node = head
            while node:
                self.add(node.value)
                node = node.next

    def add(self, item):
        if self._size + 1 > self._capacity * self._load_factor:
            self._resize(self._capacity * 2)

        idx = self._index(item)
        node = self._buckets[idx]
        while node:
            if node.value == item:
                return
            node = node.next

        self._buckets[idx] = _HashSetNode(item, self._buckets[idx])
        self._size += 1

    def remove(self, item):
        idx = self._index(item)
        prev = None
        curr = self._buckets[idx]
        while curr:
            if curr.value == item:
                if prev is None:
                    self._buckets[idx] = curr.next
                else:
                    prev.next = curr.next
                self._size -= 1
                return
            prev, curr = curr, curr.next
        raise KeyError(f"{item} not in set")

    def discard(self, item):
        try:
            self.remove(item)
        except KeyError:
            return

    def __contains__(self, item):
        idx = self._index(item)
        node = self._buckets[idx]
        while node:
            if node.value == item:
                return True
            node = node.next
        return False

    def __len__(self):
        return self._size

    def __iter__(self):
        for head in self._buckets:
            node = head
            while node:
                yield node.value
                node = node.next

    def __repr__(self):
        return f"HashSet({list(self)!r})"

    def union(self, other):
        result = HashSet(self)
        for x in other:
            result.add(x)
        return result

    def intersection(self, other):
        result = HashSet()
        if len(self) <= len(other):
            for x in self:
                if x in other:
                    result.add(x)
        else:
            for x in other:
                if x in self:
                    result.add(x)
        return result

    def difference(self, other):
        result = HashSet()
        for x in self:
            if x not in other:
                result.add(x)
        return result
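
The growth rule in `add` above (double the bucket array whenever one more item would exceed `capacity * load_factor`) can be replayed without building a table. The helper below is a sketch with a hypothetical name, `capacity_after`, and it assumes every inserted item is distinct.

```python
def capacity_after(n_items, initial_capacity=8, load_factor=0.75):
    """Replay the growth rule: double whenever one more item would exceed the load factor."""
    capacity = initial_capacity
    size = 0
    for _ in range(n_items):
        if size + 1 > capacity * load_factor:
            capacity *= 2
        size += 1
    return capacity


for n in (6, 7, 12, 13, 100):
    print(n, capacity_after(n))
```

With the defaults, the seventh insertion triggers the first resize (8 to 16 buckets). Because capacity doubles each time, the total rehashing work stays proportional to the number of insertions, which is what keeps `add` amortized $O(1)$.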

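2. HashMap: A Dictionary Backed by a Hash Table (Separate Chaining)#

A set stores only values; a map/dict stores key→value mappings. Implementation-wise, the only change is that each node stores a (key, value) pair.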

class _HashMapNode:
    def __init__(self, key, value, next_node=None):
        self.key = key
        self.value = value
        self.next = next_node


class HashMap:
    """Map implemented with separate chaining. The interface is modeled on dict."""

    def __init__(self, items=None, initial_capacity=8, load_factor=0.75):
        self._capacity = max(2, int(initial_capacity))
        self._buckets = [None] * self._capacity
        self._size = 0
        self._load_factor = float(load_factor)

        if items is not None:
            for k, v in items:
                self[k] = v

    def _index(self, key):
        return (hash(key) & 0x7FFFFFFF) % self._capacity

    def _resize(self, new_capacity):
        old = self._buckets
        self._capacity = int(new_capacity)
        self._buckets = [None] * self._capacity
        self._size = 0

        for head in old:
            node = head
            while node:
                self[node.key] = node.value
                node = node.next

    def __setitem__(self, key, value):
        if self._size + 1 > self._capacity * self._load_factor:
            self._resize(self._capacity * 2)

        idx = self._index(key)
        node = self._buckets[idx]
        while node:
            if node.key == key:
                node.value = value
                return
            node = node.next

        self._buckets[idx] = _HashMapNode(key, value, self._buckets[idx])
        self._size += 1

    def __getitem__(self, key):
        idx = self._index(key)
        node = self._buckets[idx]
        while node:
            if node.key == key:
                return node.value
            node = node.next
        raise KeyError(key)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        idx = self._index(key)
        node = self._buckets[idx]
        while node:
            if node.key == key:
                return True
            node = node.next
        return False

    def __delitem__(self, key):
        idx = self._index(key)
        prev = None
        curr = self._buckets[idx]
        while curr:
            if curr.key == key:
                if prev is None:
                    self._buckets[idx] = curr.next
                else:
                    prev.next = curr.next
                self._size -= 1
                return
            prev, curr = curr, curr.next
        raise KeyError(key)

    def __len__(self):
        return self._size

    def keys(self):
        for head in self._buckets:
            node = head
            while node:
                yield node.key
                node = node.next

    def values(self):
        for head in self._buckets:
            node = head
            while node:
                yield node.value
                node = node.next

    def items(self):
        for head in self._buckets:
            node = head
            while node:
                yield node.key, node.value
                node = node.next

    def __iter__(self):
        return self.keys()

    def __repr__(self):
        return f"HashMap({dict(self.items())!r})"
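
Section 1.1 notes that a chain can be a dynamic array instead of a linked list. The sketch below shows that variant in its smallest form, assuming a fixed number of buckets and no resizing; the function names `make_table`, `table_put`, and `table_get` are hypothetical and not part of the HashMap class above.

```python
def make_table(num_buckets=8):
    # One Python list per bucket; each entry is a [key, value] pair
    return [[] for _ in range(num_buckets)]


def table_put(table, key, value):
    bucket = table[hash(key) % len(table)]
    for pair in bucket:
        if pair[0] == key:
            pair[1] = value  # key already present: overwrite in place
            return
    bucket.append([key, value])


def table_get(table, key, default=None):
    bucket = table[hash(key) % len(table)]
    for k, v in bucket:
        if k == key:
            return v
    return default


table = make_table()
table_put(table, "apple", 1)
table_put(table, "pear", 2)
table_put(table, "apple", 3)  # overwrites the earlier value
print(table_get(table, "apple"), table_get(table, "plum", "missing"))
```

Python's `%` with a positive modulus never returns a negative number, so this sketch skips the bit mask. Without resizing, chains grow with the number of keys, so this form is only reasonable for small tables.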

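3. Stack / Queue / Deque: Ring Buffer#

3.1 Why a Ring Buffer#

Python's list.append() and list.pop() are amortized $O(1)$ at the tail, but implementing a queue with pop(0) degrades to $O(n)$ per dequeue, because every remaining element shifts left by one.

A ring buffer uses one array plus two indices that move around a "ring":

  • Enqueue: write at tail, then tail = (tail + 1) % cap
  • Dequeue: read at head, then head = (head + 1) % cap

Queue and deque operations then cost $O(1)$ at both ends (amortized once the buffer is allowed to grow).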
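
The two index updates can be seen in isolation in a fixed-capacity queue. This is a minimal sketch: `RingQueue` is a hypothetical name, and it raises when full instead of growing, so that the wrap-around is easy to follow.

```python
class RingQueue:
    """Fixed-capacity FIFO queue on a ring buffer (hypothetical name, no growth)."""

    def __init__(self, capacity):
        self._data = [None] * capacity
        self._cap = capacity
        self._head = 0  # index of the oldest element
        self._tail = 0  # index where the next element is written
        self._size = 0

    def enqueue(self, x):
        if self._size == self._cap:
            raise OverflowError("queue is full")
        self._data[self._tail] = x
        self._tail = (self._tail + 1) % self._cap
        self._size += 1

    def dequeue(self):
        if self._size == 0:
            raise IndexError("dequeue from empty queue")
        x = self._data[self._head]
        self._data[self._head] = None
        self._head = (self._head + 1) % self._cap
        self._size -= 1
        return x


q = RingQueue(4)
for value in (1, 2, 3):
    q.enqueue(value)
first = q.dequeue()           # frees slot 0
q.enqueue(4)                  # fills slot 3, tail wraps to 0
q.enqueue(5)                  # reuses slot 0
drained = [q.dequeue() for _ in range(4)]
print(first, drained)
```

The element 5 is stored physically before 2, 3, and 4 in the array, yet it leaves last, because order is defined by the head index and not by array position.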

3.2 Deque Implementation#

class Deque:
    def __init__(self, initial_capacity=8):
        self._cap = max(2, int(initial_capacity))
        self._data = [None] * self._cap
        self._head = 0
        self._tail = 0
        self._size = 0

    def __len__(self):
        return self._size

    def _grow(self):
        new_cap = self._cap * 2
        new_data = [None] * new_cap
        for i in range(self._size):
            new_data[i] = self._data[(self._head + i) % self._cap]
        self._data = new_data
        self._cap = new_cap
        self._head = 0
        self._tail = self._size

    def append(self, x):
        if self._size == self._cap:
            self._grow()
        self._data[self._tail] = x
        self._tail = (self._tail + 1) % self._cap
        self._size += 1

    def appendleft(self, x):
        if self._size == self._cap:
            self._grow()
        self._head = (self._head - 1) % self._cap
        self._data[self._head] = x
        self._size += 1

    def pop(self):
        if self._size == 0:
            raise IndexError("pop from empty deque")
        self._tail = (self._tail - 1) % self._cap
        x = self._data[self._tail]
        self._data[self._tail] = None
        self._size -= 1
        return x

    def popleft(self):
        if self._size == 0:
            raise IndexError("pop from empty deque")
        x = self._data[self._head]
        self._data[self._head] = None
        self._head = (self._head + 1) % self._cap
        self._size -= 1
        return x

    def peekleft(self):
        if self._size == 0:
            raise IndexError("peek from empty deque")
        return self._data[self._head]

    def peek(self):
        if self._size == 0:
            raise IndexError("peek from empty deque")
        return self._data[(self._tail - 1) % self._cap]

Stacks and queues fall out of Deque directly:

  • Stack: use only append/pop
  • Queue: use only append/popleft
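
As a sketch of those two restrictions, the wrappers below expose only the allowed operations. To keep the block self-contained it uses the standard library's `collections.deque` as a stand-in, since it offers the same append/pop/popleft methods; the class names `Stack` and `Queue` are illustrative.

```python
from collections import deque


class Stack:
    """LIFO: push and pop at the same end."""

    def __init__(self):
        self._items = deque()

    def push(self, x):
        self._items.append(x)

    def pop(self):
        return self._items.pop()

    def __len__(self):
        return len(self._items)


class Queue:
    """FIFO: enqueue at the tail, dequeue from the head."""

    def __init__(self):
        self._items = deque()

    def enqueue(self, x):
        self._items.append(x)

    def dequeue(self):
        return self._items.popleft()

    def __len__(self):
        return len(self._items)


s, q = Stack(), Queue()
for value in (1, 2, 3):
    s.push(value)
    q.enqueue(value)
print(s.pop(), q.dequeue())  # last in vs. first in
```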

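4. Heap (Priority Queue): Array Implementation#

A priority queue supports:

  • push(x): add an element
  • pop(): remove and return the smallest (or largest) element

The usual implementation is a binary heap stored in an array; parent and child positions are computed from indices, so no explicit tree nodes are needed. For the node at index i:

  • parent = (i - 1) // 2
  • left = 2i + 1, right = 2i + 2
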
class MinHeap:
    def __init__(self, items=None):
        self._a = []
        if items is not None:
            for x in items:
                self.push(x)

    def __len__(self):
        return len(self._a)

    def peek(self):
        if not self._a:
            raise IndexError("peek from empty heap")
        return self._a[0]

    def push(self, x):
        a = self._a
        a.append(x)
        i = len(a) - 1
        while i > 0:
            p = (i - 1) // 2
            if a[p] <= a[i]:
                break
            a[p], a[i] = a[i], a[p]
            i = p

    def pop(self):
        a = self._a
        if not a:
            raise IndexError("pop from empty heap")
        if len(a) == 1:
            return a.pop()
        top = a[0]
        a[0] = a.pop()
        self._sift_down(0)
        return top

    def _sift_down(self, i):
        a = self._a
        n = len(a)
        while True:
            l = 2 * i + 1
            r = l + 1
            smallest = i
            if l < n and a[l] < a[smallest]:
                smallest = l
            if r < n and a[r] < a[smallest]:
                smallest = r
            if smallest == i:
                break
            a[i], a[smallest] = a[smallest], a[i]
            i = smallest

Complexity: push and pop are both $O(\log n)$; peek is $O(1)$.
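
The parent formula is enough to check the heap property on any list. The sketch below defines a hypothetical helper, `is_min_heap`, and uses the standard library's `heapq` module, which stores a min-heap in a plain list with the same 0-based layout.

```python
import heapq


def is_min_heap(a):
    # Every non-root element must be >= its parent at (i - 1) // 2
    return all(a[(i - 1) // 2] <= a[i] for i in range(1, len(a)))


data = [9, 4, 7, 1, 8, 2]
print(is_min_heap(data))      # False: 9 at the root is larger than its children

heapq.heapify(data)           # standard-library in-place heap construction
print(is_min_heap(data))      # True after heapify

drained = [heapq.heappop(data) for _ in range(len(data))]
print(drained)                # ascending order
```

Popping until the heap is empty yields the elements in sorted order, which is the idea behind heap sort.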


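5. Disjoint Set Union (Union-Find)#

A disjoint set union maintains a collection of disjoint sets under dynamic merging and querying:

  • find(x): return the representative (root) of x's set
  • union(x, y): merge the sets containing x and y

Key optimizations:

  • Path compression: during find, re-attach the nodes on the path closer to the root. Full compression points them directly at the root; the code below uses path halving, a one-pass variant that points each visited node at its grandparent.
  • Union by rank / union by size: attach the smaller tree under the larger one.
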
class DisjointSetUnion:
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]
        return True

    def same(self, a, b):
        return self.find(a) == self.find(b)

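With both optimizations, the amortized cost per operation is nearly $O(1)$ (more precisely $O(\alpha(n))$, where $\alpha$ is the inverse Ackermann function).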
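
A typical use is counting connected components of an undirected graph. The following is a self-contained sketch with the same path halving and union by size inlined; the function name `count_components` and the sample edge list are made up for illustration.

```python
def count_components(n, edges):
    """Count connected components among nodes 0..n-1 using union-find."""
    parent = list(range(n))
    size = [1] * n

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    components = n
    for a, b in edges:
        ra, rb = find(a), find(b)
        if ra == rb:
            continue  # already connected: the edge closes a cycle
        if size[ra] < size[rb]:
            ra, rb = rb, ra
        parent[rb] = ra  # attach the smaller tree under the larger one
        size[ra] += size[rb]
        components -= 1
    return components


print(count_components(6, [(0, 1), (1, 2), (3, 4), (2, 0)]))
```

Here nodes 0, 1, 2 form one component, 3 and 4 another, and 5 is isolated; the edge (2, 0) is skipped because both endpoints already share a root.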


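6. Trie (Prefix Tree): String Prefix Queries#

A trie is a good fit for:

  • Exact word lookup: search(word)
  • Prefix queries: starts_with(prefix)

It expands strings character by character, one level per character, so query time is proportional to the string length $L$: $O(L)$.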

class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end = False


class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        node = self.root
        for ch in word:
            node = node.children.setdefault(ch, TrieNode())
        node.is_end = True

    def search(self, word):
        node = self.root
        for ch in word:
            if ch not in node.children:
                return False
            node = node.children[ch]
        return node.is_end

    def starts_with(self, prefix):
        node = self.root
        for ch in prefix:
            if ch not in node.children:
                return False
            node = node.children[ch]
        return True
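
A common extension of starts_with is listing every stored word under a prefix (autocomplete). The sketch below is not the Trie class above: to stay short it uses nested dicts as nodes and a hypothetical end-of-word marker `END`, assuming "$" never occurs in the stored words.

```python
END = "$"  # hypothetical end-of-word marker; assumes "$" never appears in words


def trie_insert(root, word):
    node = root
    for ch in word:
        node = node.setdefault(ch, {})
    node[END] = True


def words_with_prefix(root, prefix):
    # Walk down to the node for the prefix, then collect everything below it
    node = root
    for ch in prefix:
        if ch not in node:
            return []
        node = node[ch]

    results = []
    stack = [(node, prefix)]
    while stack:
        current, text = stack.pop()
        for key, child in current.items():
            if key == END:
                results.append(text)
            else:
                stack.append((child, text + key))
    return sorted(results)


root = {}
for w in ("car", "card", "care", "cat", "dog"):
    trie_insert(root, w)
print(words_with_prefix(root, "car"))
```

Reaching the prefix node costs $O(L)$; the collection step then visits only the subtree below it, so unrelated words such as "dog" are never touched.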

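7. Bloom Filter: Fast "Possibly Present / Definitely Absent" Checks#

A Bloom filter represents a set with a bit array plus $k$ hash functions. Adding an item sets the $k$ bits that its hashes select.

  • Query: if any of the $k$ bits is 0, the item is definitely absent
  • If all of them are 1, the item is possibly present (false positives can occur)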
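
Both guarantees can be observed on a small filter. This sketch is built on assumptions chosen for reproducibility: it simulates the $k$ hash functions by salting SHA-256 from `hashlib` (Python's built-in `hash` for strings varies between processes), uses one byte per bit position for readability, and the helper names and sizes are illustrative.

```python
import hashlib


def bit_positions(item, m, k):
    # k hash functions simulated by salting SHA-256 with the function number
    positions = []
    for i in range(k):
        digest = hashlib.sha256(f"{i}:{item}".encode("utf-8")).digest()
        positions.append(int.from_bytes(digest[:8], "big") % m)
    return positions


def bloom_add(bits, item, k):
    for pos in bit_positions(item, len(bits), k):
        bits[pos] = 1


def bloom_might_contain(bits, item, k):
    return all(bits[pos] for pos in bit_positions(item, len(bits), k))


bits = bytearray(1024)  # one byte per bit position, for readability
K = 4
added = [f"user-{i}" for i in range(100)]
for name in added:
    bloom_add(bits, name, K)

# Every added item is reported as present: no false negatives
print(all(bloom_might_contain(bits, name, K) for name in added))

# Items never added may still be reported present: false positives
others = [f"guest-{i}" for i in range(1000)]
false_positives = sum(bloom_might_contain(bits, name, K) for name in others)
print(false_positives, "false positives out of", len(others))
```

The first line printed is always True. The false-positive count depends on the hash values, but it should be a small fraction of the 1000 probes at this filter size.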

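A common engineering practice is double hashing, which derives all $k$ hash values from two base hashes (each result is then reduced modulo the number of bits):

$$h_i(x) = h_1(x) + i \cdot h_2(x)$$
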
class BloomFilter:
    def __init__(self, m_bits, k_hashes=4):
        self.m = int(m_bits)
        self.k = int(k_hashes)
        self.bits = bytearray((self.m + 7) // 8)

    def _hashes(self, item):
        h1 = hash((item, 0)) & 0x7FFFFFFF
        h2 = hash((item, 1)) & 0x7FFFFFFF
        for i in range(self.k):
            yield (h1 + i * h2) % self.m

    def add(self, item):
        for h in self._hashes(item):
            self.bits[h >> 3] |= 1 << (h & 7)

    def __contains__(self, item):
        for h in self._hashes(item):
            if (self.bits[h >> 3] >> (h & 7)) & 1 == 0:
                return False
        return True

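Use cases: large-scale deduplication, cache-penetration protection, and fast pre-filtering (check the Bloom filter first, then fall back to the database or a HashSet for the exact lookup).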
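
Choosing `m_bits` and `k_hashes` is usually done with the textbook sizing formulas, which the notes above do not cover. The sketch below applies them under the standard assumption of independent, uniformly distributed hashes; the function names are hypothetical.

```python
import math


def bloom_parameters(n_items, target_fp_rate):
    """Textbook sizing: m = -n ln p / (ln 2)^2 bits and k = (m / n) ln 2 hashes."""
    m_bits = math.ceil(-n_items * math.log(target_fp_rate) / (math.log(2) ** 2))
    k_hashes = max(1, round(m_bits / n_items * math.log(2)))
    return m_bits, k_hashes


def estimated_fp_rate(m_bits, k_hashes, n_items):
    # Standard approximation: (1 - e^(-k n / m))^k
    return (1 - math.exp(-k_hashes * n_items / m_bits)) ** k_hashes


m, k = bloom_parameters(1000, 0.01)
print(m, k, estimated_fp_rate(m, k, 1000))
```

For 1000 items and a 1% target this comes out to roughly 9.6 bits per item and 7 hash functions, far less memory than storing the items themselves.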


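8. LRU Cache: Hash Table + Doubly Linked List#

An LRU (Least Recently Used) cache must support:

  • get(key): on a hit, mark the entry as most recently used
  • put(key, value): when capacity is exceeded, evict the least recently used entry

To make both operations $O(1)$:

  • Hash table: key → node (fast lookup)
  • Doubly linked list: maintains usage order (fast move-to-front and eviction)
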
class _LRUNode:
    def __init__(self, key=None, value=None):
        self.key = key
        self.value = value
        self.prev = None
        self.next = None


class LRUCache:
    def __init__(self, capacity):
        self.capacity = int(capacity)
        self.map = {}

        # Sentinel nodes: head <-> ... <-> tail
        self.head = _LRUNode()
        self.tail = _LRUNode()
        self.head.next = self.tail
        self.tail.prev = self.head

    def _remove(self, node):
        p, n = node.prev, node.next
        p.next = n
        n.prev = p

    def _add_to_front(self, node):
        node.prev = self.head
        node.next = self.head.next
        self.head.next.prev = node
        self.head.next = node

    def get(self, key, default=None):
        if key not in self.map:
            return default
        node = self.map[key]
        self._remove(node)
        self._add_to_front(node)
        return node.value

    def put(self, key, value):
        if key in self.map:
            node = self.map[key]
            node.value = value
            self._remove(node)
            self._add_to_front(node)
            return

        node = _LRUNode(key, value)
        self.map[key] = node
        self._add_to_front(node)

        if len(self.map) > self.capacity:
            # Evict tail.prev, the least recently used node
            lru = self.tail.prev
            self._remove(lru)
            del self.map[lru.key]
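
The same policy can be sketched more compactly with the standard library's `collections.OrderedDict`, whose `move_to_end` and `popitem(last=False)` play the roles of the move-to-front and eviction steps. This is an alternative sketch, not the class above; the name `OrderedDictLRU` is made up, and here the most recently used key sits at the end of the ordering.

```python
from collections import OrderedDict


class OrderedDictLRU:
    """LRU cache sketch: the OrderedDict keeps keys from least to most recently used."""

    def __init__(self, capacity):
        self.capacity = int(capacity)
        self._data = OrderedDict()

    def get(self, key, default=None):
        if key not in self._data:
            return default
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key, value):
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = value
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)  # drop the least recently used entry

    def keys(self):
        return list(self._data)


cache = OrderedDictLRU(2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")        # "a" becomes the most recently used
cache.put("c", 3)     # evicts "b"
print(cache.keys())
```

Reading "a" before inserting "c" is what saves it: without that get, "a" would have been the oldest entry and would have been evicted instead of "b".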

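Summary: Choosing a Data Structure#

  • Fast membership tests / deduplication: HashSet
  • Key-value mapping: HashMap
  • FIFO / operations at both ends: ring-buffer Deque
  • Retrieving the minimum or maximum at any time: heap (priority queue)
  • Dynamic connectivity / merging sets: Disjoint Set Union
  • Prefix search / dictionaries: Trie
  • Fast filtering where false positives are acceptable: Bloom Filter
  • Cache eviction policy: LRU Cache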
