Mo's algorithm (모스 알고리즘)

 

모스 알고리즘은 오프라인 쿼리 문제에서 쿼리의 순서를 최적화 시키는 알고리즘이다.


이는 전에 포스팅한 sqrt decomposition을 이용하여 구현하게 되는데 다음과 같은 문제를 생각해보자.


백준 13547


수열의 크기가 N이고, 쿼리의 갯수가 M일 때, 쿼리를 입력받은 대로 순차적으로 처리한다면 O(NM)으로 시간초과가 발생한다. 

그렇다면 우리가 이전 쿼리의 구간에서 사용했던 정보를 이용하여 현재 쿼리 구간의 정보를 더해 사용하는 것은 어떨까?


예를 들어  다음과 같은 식이다.


Qi에서  100 ... 10000 번째 까지의 A_i에 대하여 다른 수가 몇 번 등장했는지를 뽑아낸다면, 이 때 뽑아냈던 정보를 이용해


Qi_1에서 200 ... 9000 번째 까지의 A_i에 대하여 다른 수가 몇 번 등장했는지를 뽑아내는 것이다.


이는 100을 가르키는 왼쪽 포인터가 200까지 이동하면서 순차적으로 cnt 배열의 cnt[Arr[l]]을 -1로 update하고 또한 10000을 가르키는 오른쪽 포인터가 9000까지 이동하며 cnt[Arr[l]]을 -1로 update하면 정보를 이용하여 구할 수 있을 것이다.

이 때 cnt 배열이 0이 된다면 ans 값을 -1 하면 값 종류의 갯수가 하나 적어진다는 것을 뜻하게 되고,

cnt 배열이 0에서 0이 아닌 값이 된다면 ans 값을 1하면 값 종류 갯수가 하나 늘어난다는 것을 뜻하게 된다. 


이를 바탕으로 구현하면 다음과 같이 코드를 작성할 수 있다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
struct query {
    int left, right, index;
} Q[100001];
 
void f(int index, bool add) {
    if(add) {
        if(cnt[arr[index]] == 0) {
            disNum++;
        }
        cnt[arr[index]]++;
    }
    else {
        cnt[arr[index]]--;
        if(cnt[arr[index]] == 0) {
            disNum--;
        }
    }
}
 
...
 
 
int main() {
    ...
    int lo = 0, hi = 0;
    for(int i=1; i<=M; i++) {
        while(Q[i].right > hi) {
            f(++hi, true);
        }
        while(Q[i].right < hi) {
            f(hi--false);
        }
        while(Q[i].left < lo) {
            f(--lo, true);
        }
        while(Q[i].left > lo) {
            f(lo++false);
        }
        ans[Q[i].index] = disNum;
    }
 
    for(int i=1; i<=M; i++printf("%d\n", ans[i]);
}
cs


하지만 이것으로 불충분한것이, offline query를 정렬하지 않고 사용할 경우 최악의 시간복잡도가 

O(NM)이다. 왜냐하면 Q1이 (1, 1)까지의 query이고 Q2가 (N, N)까지의 query일 경우 포인터가 O(N) 움직여야하고 이런 것이 M번 반복되면 O(NM)이기 때문이다.


그러면 offline query를 어떻게 정렬해야 할까?

여기서 전에 살펴본 sqrt decomposition을 생각해보자. 전체 N개의 블록에 대해서 \sqrt{N}개의 블록으로 쪼갠후 다음과 같이 정렬한다고 생각하자.


1
2
3
4
5
6
7
8
bool cmp(const query& a, const query& b) {
    int sq = (int)sqrt(N);
 
    int aSR = a.right / sq;
    int bSR = b.right / sq;
    
    return aSR == bSR ? a.left < b.left : aSR < bSR;
}
cs


query a의 오른쪽 포인터, query b의 오른쪽 포인터에 대하여 sq_block의 순서를 통해 query를 정렬한다.

만약 sq_block이 같을 경우  왼쪽 포인터를 비교한다.


자 이렇게 되면 먼저 오른쪽 포인터의 움직임을 생각하자.


오른쪽 포인터의 움직임은 첫째, 같은 sq_block내에서 움직이거나, 다른 블록으로 넘어가는 것임.

같은 sq_block내에서 움직일 경우, 이 경우는 모든 Query M에 대해서 O(\sqrt{N})만큼 움직이므로 O(M\sqrt{N})만큼 이동이 발생함.


자 그럼 다른 블록으로 넘어가는 경우는 어떨까?

한 번에 최대 O(N)칸을 움직여야 할수도 있지만, 전체 쿼리를 수행하는 데에 있어 블록 사이를 건너가는 이동 횟수도 총 O(N)이다. 이게 매우 중요한 포인트인데, 다른 block으로 넘어갈 경우, 이전에 있던 block으로 절대 돌아오지 않으므로 O(N)이 된다.


정리하면 M개의 쿼리에 대해 매 쿼리당 O(\sqrt{N})번을 움직이고 전체과정에서 O(N)번의 이동이 있으니 O(N + M\sqrt{N})번의 이동만 있게 된다.


왼쪽 포인터의 이동 같은 경우에는 r이 이전 쿼리와 같은 블록에 속해 있는 동안 총 O(N)번 움직인다.

r이 이전 쿼리와 다른 블록에 속해있을 경우가 O(\sqrt{N})으로, 이 경우 left pointer가 1로 초기화가 되기 때문에 총 O(N\sqrt{N})만큼 움직이게 된다.


따라서, 모스 알고리즘을 이용하면 O((N+M)\sqrt{N}) 시간에 문제를 풀 수 있다.

존내 개꿀이다.



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <stdio.h>
#include <math.h>
#include <algorithm>
 
using namespace std;
 
int N, M;
int cnt[1000001];
int arr[100001];
int ans[100001];
int disNum = 0;
 
struct query {
    int left, right, index;
} Q[100001];
 
bool cmp(const query& a, const query& b) {
    int sq = (int)sqrt(N);
 
    int aSR = a.right / sq;
    int bSR = b.right / sq;
    
    return aSR == bSR ? a.left < b.left : aSR < bSR;
}
 
void f(int index, bool add) {
    if(add) {
        if(cnt[arr[index]] == 0) {
            disNum++;
        }
        cnt[arr[index]]++;
    }
    else {
        cnt[arr[index]]--;
        if(cnt[arr[index]] == 0) {
            disNum--;
        }
    }
}
 
int main() {
    scanf("%d"&N);
    for(int i=1; i<=N; i++) {
        scanf("%d"&arr[i]);
    }
    scanf("%d"&M);
    for(int i=1; i<=M; i++) {
        scanf("%d%d"&Q[i].left, &Q[i].right);
        Q[i].index = i;
    }
    sort(Q+1, Q+M+1, cmp);
    int lo = 0, hi = 0;
    for(int i=1; i<=M; i++) {
        while(Q[i].right > hi) {
            f(++hi, true);
        }
        while(Q[i].right < hi) {
            f(hi--false);
        }
        while(Q[i].left < lo) {
            f(--lo, true);
        }
        while(Q[i].left > lo) {
            f(lo++false);
        }
        ans[Q[i].index] = disNum;
    }
 
    for(int i=1; i<=M; i++printf("%d\n", ans[i]);
}
cs

댓글

가장 많이 본 글