セグメント木

プログラミングコンテストチャレンジブック

プログラミングコンテストチャレンジブック

で、セグメント木というデータ構造を勉強したので、メモする。

問題

最初に、数列 a[i] i = 0, ..., n - 1 が与えられる。これに対して、区間[l, r)の最小値を求めよというクエリが繰り返し投げられるので、適当なデータ構造を考えなさい、という問題に対する1つの答えがセグメント木である。以下、簡単のため、n = 2^k とする。

構造

セグメント木は、深さ k の2分木である。従って、n 個の葉を持つ。セグメント木は、葉からボトムアップに、以下のように構成する。

  • 左から i 番目の葉に、値a[i]を格納する。
  • 中間ノードには、2つの子ノードが保持している値の小さい方を格納する。

こうすると、深さ i の階層で左から j 番目のノードには、区間[j 2^{k-i}, (j + 1) 2^{k - i})における最小値が格納される。但し、i, j は 0 始まり。つまり、根は深さ0。各階層で左端のノードは、左から0番目。

クエリ処理

関数getAnsを以下で定義する。

# j : ノード
# l, r : 最小値を求めたい区間の左端、右端
getAns(j, l, r):
	if [l, r) と ノード j の担当区間が重ならない
		return infinity
	else if ノード j の担当区間が[l, r)に含まれる
		return ノードの値
	else
		return min(getAns(jの左の子, l, r), getAns(jの右の子, l, r))

区間[l, r)の最小値は、getAns(ルートノード, l, r)で計算される。この処理の気持ちは、以下のような感じである。
まず、一般論として、区間Aが複数の区間B_iの和集合として表される時、

[区間Aの最小値] = min(区間B_iの最小値)
が成り立つことに注意する。
今、区間[l, r)の最小値を求めたいとする。セグメント木の各ノードには、「切りの良い」区間の最小値が事前に計算されて格納されている。上の事実から、計算済みの区間の和集合として[l, r)が表現出来れば、[l, r)の最小値は計算済みのデータだけから計算出来る。この時、[l, r)を表現するために利用される区間の数はなるべく少ない方が良い。そのためには、大きい区間から順番に、[l, r)の最小値を求めるのにその区間の計算結果を利用出来るかどうかを調べ、利用出来る奴に出会ったら利用するというルールを実行すれば良い。

考察と一般化

セグメント木が上手くいく根本的な理由の一つは、min(a, b)という演算の結合性(つまり、min(min(a, b), c) = min(a, min(b, c))が成り立つこと)である。結合性が無いと、予め部分区間について演算しておいた結果を再利用することが出来ないので、このアルゴリズムは使えない。また、min以外の演算を使う場合には、getAnsでinfinityを返している所を、その演算の単位元を返すようにすれば良い。
ただ、「結合性、結合性、。。。」とだけ思っているとちょっと思い付かないような例が、上の本の例題に乗っているので、興味のある人は、ご一読を。

実装例

2分木は配列で表現出来ることに注意。

/*! g++ main.cpp -Wall -g
 */

#include <iostream>
#include <vector>
#include <algorithm>
#include <climits>
#include <sstream>
#include <iterator>

using namespace std;

class SegmentTree{
  vector<int> tree;
  vector<int> array;
  vector<int> left, right;
  int n;

  void makeTree(int i){
    if(i >= n - 1){
      tree[i] = array[i - n + 1];
      left[i] = i - n + 1;
      right[i] = left[i] + 1;
    }
    else{
      makeTree(2 * i + 1);
      makeTree(2 * i + 2);
      tree[i] = min(tree[2 * i + 1], tree[2 * i + 2]);
      left[i] = left[2 * i + 1];
      right[i] = right[2 * i + 2];
    }
  }

  int getMin(int i, int l, int r){
    int ret;
    if(l >= right[i] || r <= left[i]) ret = INT_MAX;
    else if(l <= left[i] && right[i] <= r) ret = tree[i];
    else ret = min(getMin(2 * i + 1, l, r), getMin(2 * i + 2, l, r));
    return ret;
  }

public:
  SegmentTree(const vector<int>& a){
    int m = 1;
    while(m < (int)a.size()){
      m *= 2;
    }
    n = (int)a.size();

    array.resize(m, INT_MAX);
    copy(a.begin(), a.end(), array.begin());

    tree.resize(2 * m - 1);
    left.resize(2 * m - 1);
    right.resize(2 * m - 1);
    makeTree(0);
  }

  int getMin(int l, int r){
    return getMin(0, l, r);
  }

  void debug(){
    copy(tree.begin(), tree.end(), ostream_iterator<int>(cout, " ")); cout << endl;
    copy(array.begin(), array.end(), ostream_iterator<int>(cout, " ")); cout << endl;
  }
};

int main(int argc, char* argv[]){
  stringstream sst(argv[1]);
  int buf;
  vector<int> a;
  while(sst >> buf){
    a.push_back(buf);
  }
  SegmentTree st(a);
  cout << st.getMin(atoi(argv[2]), atoi(argv[3])) << endl;
  st.debug();

  return 0;
}