python SegmentTree
自分のライブラリ置き場です。
データ構造 SegmentTree の python コードです。
抽象化、非再帰です。非2べきあります。
汎用性の鬼です。できることが多すぎるので、慣れが必要です。私は慣れていません。
10/21 関数追加しました。max_right の逆verです。min_left(r, g)
class SegmentTree: def __init__(self, n, p, unit, f): self.n = n self.num = 2**((n-1).bit_length()) self.seg = [unit]*(self.num*2) for i in range(n): self.seg[self.num+i] = p[i] for i in range(self.num-1, 0, -1): self.seg[i] = f(self.seg[i << 1], self.seg[(i << 1)+1]) self.unit = unit self.f = f def update(self, i, x): i += self.num self.seg[i] = x while i: i >>= 1 self.seg[i] = self.f(self.seg[i << 1], self.seg[(i << 1)+1]) def query(self, l, r): ansl = ansr = self.unit l += self.num r += self.num-1 if l == r: return self.seg[l] while l < r: if l & 1: ansl = self.f(ansl, self.seg[l]) l += 1 if ~r & 1: ansr = self.f(self.seg[r], ansr) r -= 1 l >>= 1 r >>= 1 if l == r: ansl = self.f(ansl, self.seg[l]) return self.f(ansl, ansr) def max_right(self, l, g): l += self.num ll = l // (l & -l) ans = self.unit while g(self.f(ans, self.seg[ll])): ans = self.f(ans, self.seg[ll]) ll += 1 while ~ll & 1: ll >>= 1 if ll == 1: return self.n while ll < self.num: ll <<= 1 if g(self.f(ans, self.seg[ll])): ans = self.f(ans, self.seg[ll]) ll += 1 return ll-self.num def min_left(self, r, g): r += self.num rr = max(r // (~r & -~r), 1) ans = self.unit while g(self.f(self.seg[rr], ans)): ans = self.f(self.seg[rr], ans) rr -= 1 while rr & 1: rr >>= 1 if rr == 0: return -1 while rr < self.num: rr <<= 1 if g(self.f(self.seg[rr+1], ans)): ans = self.f(self.seg[rr+1], ans) else: rr += 1 return rr - self.num
セグ木内で使う関数 は を満たす関数である必要があります。結合律っていいます。
また、 を満たす が必要です。単位元っていいます。
例えば区間の総和を持ちたいときは になります。
区間の最大値の場合は になります。
gcd の場合は になります。
Twitter でつよい人とかが「モノイドなのでセグ木に乗る」と言っているのは「この2つの条件を満たしている」ことを指しています。
関数の説明を兼ねた使用例です。本体は最上部とかに貼り付けるだけなので省略しています。
# 区間の最大値を持つようにする def f(x, y): return max(x, y) n = 5 a = [1, 2, 3, 2, 1] # 初期化 # 引数は ( 配列の長さ, 配列の初期状態, 単位元, 関数 ) seg = SegmentTree(n, a, 0, f) # seg.query(l, r) ---> a[l, r) の最大値 print(seg.query(0, 2)) # ---> 2 print(seg.query(2, 5)) # ---> 3 # seg.update(i, x) ---> a[i] = x seg.update(1, 5) seg.update(2, -10) seg.update(4, 100) print(seg.query(0, 3)) # ---> 5 print(seg.query(0, 5)) # ---> 100
# 区間の総和を持つようにする def f(x, y): return x+y n = 5 a = [1, 2, 3, 2, 1] seg = SegmentTree(n, a, 0, f) # seg.query(l, r) ---> a[l, r) の総和 print(seg.query(0, 2)) # ---> 3 print(seg.query(2, 5)) # ---> 6 seg.update(1, 5) seg.update(2, -10) seg.update(4, 100) print(seg.query(0, 3)) # ---> -4 print(seg.query(0, 5)) # ---> 98
def f(x, y): return max(x, y) n = 8 a = [1, 1, 1, 2, 2, 2, 2, 3] seg = SegmentTree(n, a, 0, f) # seg.max_right(l, g) ---> g(seg.query(l, r)) を満たす最大の r # lambda式を使っています print(seg.max_right(0, lambda x: x<2)) # ---> 3 print(seg.max_right(0, lambda x: x<3)) # ---> 7 print(seg.max_right(0, lambda x: x<10)) # ---> 8 # もちろんこう書いても OK def g(x): return x<2 print(seg.max_right(0, g)) # ---> 3
def f(x, y): return x+y n = 5 a = [1, 1, 1, 1, 1] seg = SegmentTree(n, a, 0, f) print(seg.max_right(0, lambda x: x<2)) # ---> 1 print(seg.max_right(2, lambda x: x<2)) # ---> 3 print(seg.max_right(1, lambda x: x<3)) # ---> 3 print(seg.max_right(4, lambda x: x<3)) # ---> 5
提出例です
Static RMQ
def f(x, y): return min(x, y) n, q = map(int, input().split()) a = list(map(int, input().split())) seg = SegmentTree(n, a, 10**9, f) for _ in range(q): l, r = map(int, input().split()) print(seg.query(l, r))
def f(x, y): return max(x, y) n, q = map(int, input().split()) p = list(map(int, input().split())) seg = SegmentTree(n, p, 0, f) for _ in range(q): t, a, b = map(int, input().split()) a -= 1 b -= 1 if t == 1: seg.update(a, b+1) elif t == 2: print(seg.query(a, b+1)) else: print(seg.max_right(a, lambda x: x < b+1)+1)
def gcd(a, b): while b: a, b = b, a % b return a n = int(input()) p = list(map(int, input().split())) seg = SegmentTree(n, p, 0, gcd) ans = 0 for i in range(n): ans = max(ans, gcd(seg.query(0, i), seg.query(i+1, n))) print(ans)
Point Set Range Composite
ちゃんと実装しているので、 が成り立たなくてもちゃんと動きます
def func(x, y): ax, bx = x ay, by = y return (ax*ay % mod, (bx*ay+by) % mod) mod = 998244353 n, q = map(int, input().split()) f = [tuple(map(int, input().split())) for _ in range(n)] seg = SegmentTree(n, f, (1, 0), func) for _ in range(q): flag, pl, cr, dx = map(int, input().split()) if flag == 0: seg.update(pl, (cr, dx)) else: a, b = seg.query(pl, cr) print((a*dx+b) % mod)
おまけの 非2べき非再帰セグ木 です。
max_right(l, g)
は一応作ってありますが、2べきに比べて (実装が難しくて) 定数倍が重めなのでオススメしません。
セグ木内のにぶたんをするなら2べき、しないなら非2べき、みたいに私は使い分けています。
(メモリ消費や実行時間を正確に計算したわけではないです。そのうちやります)
# SegmentTree class SegmentTree: def __init__(self, n, p, unit, f): self.n = n self.seg = p*2 for i in range(self.n-1, 0, -1): self.seg[i] = f(self.seg[i << 1], self.seg[(i << 1)+1]) self.unit = unit self.f = f def update(self, i, x): i += self.n self.seg[i] = x while i: i >>= 1 self.seg[i] = self.f(self.seg[i << 1], self.seg[(i << 1)+1]) def query(self, l, r): l += self.n r += self.n-1 ansl = ansr = self.unit if l == r: return self.seg[l] while l < r: if l & 1: ansl = self.f(ansl, self.seg[l]) l += 1 if ~r & 1: ansr = self.f(self.seg[r], ansr) r -= 1 l >>= 1 r >>= 1 if l == r: ansl = self.f(ansl, self.seg[l]) return self.f(ansl, ansr) def max_right(self, l, g): l += self.n r = self.n*2-1 ll = [] rr = [] ansl = ansr = self.unit while l < r: if l & 1: ll.append(l) ansl = self.f(ansl, self.seg[l]) l += 1 if ~r & 1: rr.append(r) ansr = self.f(self.seg[r], ansr) r -= 1 l >>= 1 r >>= 1 if l == r: ll.append(l) ansl = self.f(ansl, self.seg[l]) if g(self.f(ansl, ansr)): return self.n if g(ansl): m = rr[::-1] ans = ansl else: m = ll ans = self.unit for i in m: if g(self.f(ans, self.seg[i])): ans = self.f(ans, self.seg[i]) else: l = i break while l < self.n: l <<= 1 if g(self.f(ans, self.seg[l])): ans = self.f(ans, self.seg[l]) l += 1 return l-self.n