Mo's algorithm (모스 알고리즘)

 

모스 알고리즘은 오프라인 쿼리 문제에서 쿼리의 순서를 최적화 시키는 알고리즘이다.


이는 전에 포스팅한 sqrt decomposition을 이용하여 구현하게 되는데 다음과 같은 문제를 생각해보자.


백준 13547


수열의 크기가 $N$이고, 쿼리의 갯수가 $M$일 때, 쿼리를 입력받은 대로 순차적으로 처리한다면 $O(NM)$으로 시간초과가 발생한다. 

그렇다면 우리가 이전 쿼리의 구간에서 사용했던 정보를 이용하여 현재 쿼리 구간의 정보를 더해 사용하는 것은 어떨까?


예를 들어  다음과 같은 식이다.


Qi에서  100 ... 10000 번째 까지의 $A_i$에 대하여 다른 수가 몇 번 등장했는지를 뽑아낸다면, 이 때 뽑아냈던 정보를 이용해


Qi_1에서 200 ... 9000 번째 까지의 $A_i$에 대하여 다른 수가 몇 번 등장했는지를 뽑아내는 것이다.


이는 100을 가르키는 왼쪽 포인터가 200까지 이동하면서 순차적으로 cnt 배열의 $cnt[Arr[l]]$을 -1로 update하고 또한 10000을 가르키는 오른쪽 포인터가 9000까지 이동하며 $cnt[Arr[l]]$을 -1로 update하면 정보를 이용하여 구할 수 있을 것이다.

이 때 cnt 배열이 0이 된다면 ans 값을 -1 하면 값 종류의 갯수가 하나 적어진다는 것을 뜻하게 되고,

cnt 배열이 0에서 0이 아닌 값이 된다면 ans 값을 1하면 값 종류 갯수가 하나 늘어난다는 것을 뜻하게 된다. 


이를 바탕으로 구현하면 다음과 같이 코드를 작성할 수 있다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
struct query {
    int left, right, index;
} Q[100001];
 
void f(int index, bool add) {
    if(add) {
        if(cnt[arr[index]] == 0) {
            disNum++;
        }
        cnt[arr[index]]++;
    }
    else {
        cnt[arr[index]]--;
        if(cnt[arr[index]] == 0) {
            disNum--;
        }
    }
}
 
...
 
 
int main() {
    ...
    int lo = 0, hi = 0;
    for(int i=1; i<=M; i++) {
        while(Q[i].right > hi) {
            f(++hi, true);
        }
        while(Q[i].right < hi) {
            f(hi--false);
        }
        while(Q[i].left < lo) {
            f(--lo, true);
        }
        while(Q[i].left > lo) {
            f(lo++false);
        }
        ans[Q[i].index] = disNum;
    }
 
    for(int i=1; i<=M; i++printf("%d\n", ans[i]);
}
cs


하지만 이것으로 불충분한것이, offline query를 정렬하지 않고 사용할 경우 최악의 시간복잡도가 

$O(NM)$이다. 왜냐하면 Q1이 (1, 1)까지의 query이고 Q2가 (N, N)까지의 query일 경우 포인터가 $O(N)$ 움직여야하고 이런 것이 M번 반복되면 $O(NM)$이기 때문이다.


그러면 offline query를 어떻게 정렬해야 할까?

여기서 전에 살펴본 sqrt decomposition을 생각해보자. 전체 N개의 블록에 대해서 $\sqrt{N}$개의 블록으로 쪼갠후 다음과 같이 정렬한다고 생각하자.


1
2
3
4
5
6
7
8
bool cmp(const query& a, const query& b) {
    int sq = (int)sqrt(N);
 
    int aSR = a.right / sq;
    int bSR = b.right / sq;
    
    return aSR == bSR ? a.left < b.left : aSR < bSR;
}
cs


query a의 오른쪽 포인터, query b의 오른쪽 포인터에 대하여 sq_block의 순서를 통해 query를 정렬한다.

만약 sq_block이 같을 경우  왼쪽 포인터를 비교한다.


자 이렇게 되면 먼저 오른쪽 포인터의 움직임을 생각하자.


오른쪽 포인터의 움직임은 첫째, 같은 sq_block내에서 움직이거나, 다른 블록으로 넘어가는 것임.

같은 sq_block내에서 움직일 경우, 이 경우는 모든 Query M에 대해서 $O(\sqrt{N})$만큼 움직이므로 $O(M\sqrt{N})$만큼 이동이 발생함.


자 그럼 다른 블록으로 넘어가는 경우는 어떨까?

한 번에 최대 $O(N)$칸을 움직여야 할수도 있지만, 전체 쿼리를 수행하는 데에 있어 블록 사이를 건너가는 이동 횟수도 총 $O(N)$이다. 이게 매우 중요한 포인트인데, 다른 block으로 넘어갈 경우, 이전에 있던 block으로 절대 돌아오지 않으므로 $O(N)$이 된다.


정리하면 M개의 쿼리에 대해 매 쿼리당 $O(\sqrt{N})$번을 움직이고 전체과정에서 $O(N)$번의 이동이 있으니 $O(N + M\sqrt{N})$번의 이동만 있게 된다.


왼쪽 포인터의 이동 같은 경우에는 r이 이전 쿼리와 같은 블록에 속해 있는 동안 총 $O(N)$번 움직인다.

r이 이전 쿼리와 다른 블록에 속해있을 경우가 $O(\sqrt{N})$으로, 이 경우 left pointer가 1로 초기화가 되기 때문에 총 $O(N\sqrt{N})$만큼 움직이게 된다.


따라서, 모스 알고리즘을 이용하면 $O((N+M)\sqrt{N})$ 시간에 문제를 풀 수 있다.

존내 개꿀이다.



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <stdio.h>
#include <math.h>
#include <algorithm>
 
using namespace std;
 
int N, M;
int cnt[1000001];
int arr[100001];
int ans[100001];
int disNum = 0;
 
struct query {
    int left, right, index;
} Q[100001];
 
bool cmp(const query& a, const query& b) {
    int sq = (int)sqrt(N);
 
    int aSR = a.right / sq;
    int bSR = b.right / sq;
    
    return aSR == bSR ? a.left < b.left : aSR < bSR;
}
 
void f(int index, bool add) {
    if(add) {
        if(cnt[arr[index]] == 0) {
            disNum++;
        }
        cnt[arr[index]]++;
    }
    else {
        cnt[arr[index]]--;
        if(cnt[arr[index]] == 0) {
            disNum--;
        }
    }
}
 
int main() {
    scanf("%d"&N);
    for(int i=1; i<=N; i++) {
        scanf("%d"&arr[i]);
    }
    scanf("%d"&M);
    for(int i=1; i<=M; i++) {
        scanf("%d%d"&Q[i].left, &Q[i].right);
        Q[i].index = i;
    }
    sort(Q+1, Q+M+1, cmp);
    int lo = 0, hi = 0;
    for(int i=1; i<=M; i++) {
        while(Q[i].right > hi) {
            f(++hi, true);
        }
        while(Q[i].right < hi) {
            f(hi--false);
        }
        while(Q[i].left < lo) {
            f(--lo, true);
        }
        while(Q[i].left > lo) {
            f(lo++false);
        }
        ans[Q[i].index] = disNum;
    }
 
    for(int i=1; i<=M; i++printf("%d\n", ans[i]);
}
cs

댓글