python Union-Find
自分のライブラリ置き場です。
データ構造 UnionFind の python コードです。
こんなことができます。
- x を含む集合の根
find(x)
- x を含む集合と y を含む集合の合体
unite(x, y)
- x と y が同じ集合に含まれるか判定
same(x, y)
- x が含まれる集合の要素の個数
size(x)
- 集合の個数
set_count()
# UnionFind (0-indexed) # (n[i]<0)のとき, iは根, -n[i]がその集合のsize # (n[i]>=0)のとき, n[i]がiの親 # r[i]はiのに続く子の深さ # coは集合の個数 class UnionFind: def __init__(self, n): self.n = [-1]*n self.r = [0]*n self.co = n def find(self, x): if self.n[x] < 0: return x else: self.n[x] = self.find(self.n[x]) return self.n[x] def unite(self, x, y): x = self.find(x) y = self.find(y) if x == y: return if self.r[x] > self.r[y]: x, y = y, x self.n[y] += self.n[x] self.n[x] = y if self.r[x] == self.r[y]: self.r[y] += 1 self.co -= 1 def same(self, x, y): return self.find(x) == self.find(y) def size(self, x): return -self.n[self.find(x)] def set_count(self): return self.co
使用例です。UnionFind は最上部とかに貼り付けるだけなので省略しています。
N 個の頂点、Q 個のクエリ
t=0 なら頂点 (u, v) に辺をつなぐ
t=1 なら頂点 (u, v) が連結か判定
n, q = map(int, input().split()) uf = UnionFind(n) for _ in range(q): t, u, v = map(int, input().split()) if t == 0: uf.unite(u, v) if t == 1: if uf.same(u, v): print(1) else: print(0)
N 個の頂点、M 本の辺
各頂点について、その頂点と連結している頂点の個数を求める
n, m = map(int, input().split()) uf = UnionFind(n) for _ in range(m): u, v= map(int, input().split()) uf.unite(u, v) for i in range(n): print(uf.size(i))
N 個の頂点、M 本の辺
辺を最低何本追加すれば全ての頂点を連結にできるか
n, m = map(int, input().split()) uf = UnionFind(n) for _ in range(m): u, v= map(int, input().split()) uf.unite(u, v) print(uf.set_count()-1)
N 個の頂点、Q 個のクエリ
t=0 のとき頂点 (u, v) に赤色の辺をつなぐ
t=1 のとき頂点 (u, v) に青色の辺をつなぐ
t=2 のとき頂点 (u, v) が赤色の辺, 青色の辺どちらについても連結か判定
n, q = map(int, input().split()) red_uf = UnionFind(n) blue_uf = UnionFind(n) for _ in range(q): t, u, v = map(int, input().split()) if t == 0: red_uf.unite(u, v) if t == 1: blue_uf.unite(u, v) if t == 2: if red_uf.same(u, v) and blue_uf.same(u, v): print(1) else: print(0)
UnionFind に少し書き加えることで、答えられるクエリが増えたりします。
実装を理解しておくと応用できます。
N 個の頂点、Q 個のクエリ
t=0 のとき、頂点 (u, v) に辺をつなぐ
t=1 のとき、頂点 u から移動することで到達できる辺の個数を求める
# edgeco[find(x)]...xから移動することで到達できる辺の個数 class UnionFind: def __init__(self, n): self.n = [-1]*n self.r = [0]*n self.co = n self.edgeco = [0]*n #追加 def find(self, x): if self.n[x] < 0: return x else: self.n[x] = self.find(self.n[x]) return self.n[x] def unite(self, x, y): x = self.find(x) y = self.find(y) if x == y: self.edgeco[x] += 1 #追加 return if self.r[x] > self.r[y]: x, y = y, x self.edgeco[y] += self.edgeco[x]+1 #追加 self.n[y] += self.n[x] self.n[x] = y if self.r[x] == self.r[y]: self.r[y] += 1 self.co -= 1 def same(self, x, y): return self.find(x) == self.find(y) def size(self, x): return -self.n[self.find(x)] def set_count(self): return self.co # 追加 def edge_count(self, x): return self.edgeco[self.find(x)] n, q = map(int, input().split()) uf = UnionFind(n) for _ in range(q): t, u, v = map(int, input().split()) if t == 0: uf.unite(u, v) if t == 1: print(uf.edge_count(u))
N 個の頂点、各頂点には報酬 a[i] がある
Q 個のクエリ
t=0 のとき頂点 (u, v) に辺をつなぐ
t=1 のとき「頂点を 1 つ選び、その頂点と連結な頂点の報酬を全て得る」の報酬の最大値を求める
# a[find(x)]...頂点xと連結な頂点の報酬の総和 # maxa...max(a)を常に保持 class UnionFind: def __init__(self, n, a): self.n = [-1]*n self.r = [0]*n self.co = n self.a = a #追加 self.maxa = max(a) #追加 def find(self, x): if self.n[x] < 0: return x else: self.n[x] = self.find(self.n[x]) return self.n[x] def unite(self, x, y): x = self.find(x) y = self.find(y) if x == y: return if self.r[x] > self.r[y]: x, y = y, x self.a[y] += self.a[x] #追加 self.maxa = max(self.maxa, self.a[y]) #追加 self.n[y] += self.n[x] self.n[x] = y if self.r[x] == self.r[y]: self.r[y] += 1 self.co -= 1 def same(self, x, y): return self.find(x) == self.find(y) def size(self, x): return -self.n[self.find(x)] def set_count(self): return self.co # 追加 def max_a(self): return self.maxa n, q = map(int, input().split()) a = list(map(int, input().split())) uf = UnionFind(n, a) for _ in range(q): t, u, v = map(int, input().split()) if t == 0: uf.unite(u, v) if t == 1: print(uf.max_a())