union find 習作

union findという便利なアルゴリズムをこの間のtopcoder養成講座で知った。簡単そうなので、練習。ここでの問題は、縦横nマスの碁盤目に文字が入っている(これは長さnの文字列のサイズnのリストとして与えられる)時に、単一の文字から成る連結領域で最大のものを*で埋めろというもの。例えば、

111
121
212

に対して

***
*2*
212

を返すようなプログラムを書けというもの。領域は、縦横に同じ文字があれば伸ばせるけど、斜めには伸ばせない。なので、上の例で3段目の真ん中の1は「最大の連結領域」には属さない、という感じのルール。

/*! g++ main.cpp -g -Wall
 */

#include <vector>
#include <iostream>
#include <algorithm>
#include <iterator>
#include <fstream>
#include <string>

using namespace std;

class Doit{
  class UnionFind{
    vector<int> data;
  public:
    UnionFind(int n) : data(n, -1){
    }
    int root(int x){
      if(data[x] >= 0) return data[x] = root(data[x]);
      else return x;
    }
    void merge(int x, int y){
      int rx = root(x);
      int ry = root(y);
      if(rx != ry){
    data[rx] += data[ry];
    data[ry] = rx;
      }
    }
    int getMaxGroup(){
      return min_element(data.begin(), data.end()) - data.begin();
    }
    bool find(int x, int y){
      return root(x) == root(y);
    }
  };

  const int n;
  UnionFind uf;
  vector<string> field;
  vector<int> dx, dy;
public:
  Doit(vector<string> s) :
    n(s.size()),
    uf(n * n),
    field(s),
    dx(2),
    dy(2)
  {
    dx[0] = 0; dx[1] = 1;
    dy[0] = 1; dy[1] = 0;
  }

  void calc(){
    copy(field.begin(), field.end(), ostream_iterator<string>(cout, "\n"));
    cout << endl;
    for(int i = 0; i < n; i++){
      for(int j = 0; j < n; j++){
    for(int k = 0; k < 2; k++){
      if(check(i, j, i + dx[k], j + dy[k])){
        uf.merge(i * n + j, (i + dx[k]) * n + j + dy[k]);
      }
    }
      }
    }
    int me = uf.getMaxGroup();
    for(int i = 0; i < n; i++){
      for(int j = 0; j < n; j++){
    if(uf.find(i * n + j, me)) field[i][j] = '*';
      }
    }
    copy(field.begin(), field.end(), ostream_iterator<string>(cout, "\n"));
  }

  bool check(int i, int j, int x, int y){
    return
      x < n &&
      y < n &&
      field[i][j] == field[x][y];
  }
};

int main(int argc, char* argv[]){
  ifstream ifs(argv[1]);
  string buf;
  vector<string> a;
  while(!(ifs >> buf).eof()) a.push_back(buf);
  Doit doit(a);
  doit.calc();
}
  • union find は、排反な集合族を管理するためのアルゴリズム・データ構造。
    • 2つの集合のマージ・与えられた2要素が同じ集合に属するかという問い合わせ の2つの操作を持つ。
    • 1つの集合に属する要素を1つの木として持つことで、後者の同一集合への帰属性チェックは、同じルートを持つかどうかでチェック出来る。
    • さらに、ある要素のルートを探すために親の方向へ辿る時、その要素をルートの直接の子となるように辺を付け替えることで、木の高さを低く保つ。
      • ここがキモ。キモなんだけど、実装は再帰を使って綺麗に書ける。賢いこと考えるねぇ。
    • 2つの集合のマージは、一方のルートを他方のルートの直接の子とするのみ。
    • 木といいつつ、実際は配列 data。data[i]は、iの親要素のインデックス。ただし、iがルートならば自分の属する集合の要素数
    • 要素の数でなくて要素の数の-1倍を持つことで、自分がルートでない場合の親へのポインタと兼用になっている。
    • なので、-1での初期化の意味は、1要素から成る集合が要素数だけあるという状態。
      • ルートの場合に要素数を入れるというのはunion findの本質とは関係無い。

グラフの中の頂点が隣接しているか?とか、あるマスとその隣が同じ文字か?みたいな局所的な情報からグラフの連結成分とか升目の中での連結領域とか大域的な情報を効率的に出すためのものと思えば良いのかな。集合に関するデータ(上の例では、その集合の要素数)をルートで一括して持っておくと同時に、各要素からルートへのアクセスを常に短く確保しておくことで、各要素と集合との関わりが簡単に引っ張り出せる、という感じか。