본문 바로가기
컴퓨터공학/알고리즘 ˙ 자료구조

[자료구조] 세그먼트 트리을 이용해 구간 합 구하기

by 독서왕뼝아리 2023. 5. 21.
구간 합

구간 합을 구하기 위해 사용되는 방법은 세 가지이다.

 

1. for문으로 해결하기

int ans = 0;
for (int i=l; i<=r; i++) {
    ans += a[i];
}

구간 시작점(left)와 끝점(right)까지 직접 더해서 구하는 방법이다. O(N)의 시간복잡도를 가진다.

 

 

2. 누적 합 이용하기

int sum[100];
sum[0] = a[0];
for(int i=1; i<100; i++){
	sum[i] = sum[i-1] + a[i];
}

첫 인덱스부터 마지막 인덱스까지 누적된 합을 저장하는 배열을 만들어서 끝점과 시작점의 차를 이용해 구간 합을 구하는 방법이다. O(N)의 시간복잡도를 가진다.

 

 

3. 세그먼트 트리

트리 형태로 구간 합을 갖는 자료구조이다. 이진 트리이므로 O(logN)의 시간 복잡도만으로 구간 합을 만들 수 있다.

 

 

그럼 누적 합이랑 무엇이 다를까?

세그먼트 트리는 원본 배열 원소가 업데이트 되는 상황에서 매우 유리하다. 만약 누적합 배열을 만들었는데 중간의 한 원소값이 변경이 되었다면 이후 배열 모두 업데이트 해야 하기 때문에 시간이 오래 걸린다.

 

 

따라서 세그먼트 트리는 업데이트가 잦고, 배열이 무진장 클 때!! 유리한 자료구조 이다.

 


1. 세그먼트 트리 만들기

리프 노드를 제외한 다른 모든 노드는 항상 2개의 자식을 가진다. 따라서 세그먼트 트리는 정이진트리(Full Binary Tree)의 형태를 가진다. 만약 N이 2의 제곱꼴이면 포화이진트리(Perfect Binary Tree)가 된다.

void init(vector<long long> &a, vector<long long> &tree, int node, int start, int end) {
    if (start == end) { // 리프노드
        tree[node] = a[start];
    } else {
        init(a, tree, node*2, start, (start+end)/2);
        init(a, tree, node*2+1, (start+end)/2+1, end);
        tree[node] = tree[node*2] + tree[node*2+1];
    }
}

후위탐색으로 노드의 값을 저장해준다!

 

2. 구간 합 구하기

구간 left, right가 주어졌을 때 합을 구하려면 트리를 루트부터 순회하면서 각 노드에 저장된 구간의 정보와 left, right의 관계를 살펴봐야 한다.

각 노드의 저장된 구간이 [start, end]이고, 구간 합 구간을 [left, right]라고 하자. 다음과 같은 4가지 경우로 나누어진다.

 

(1) [left, right]와 [start, end]가 겹치지 않는 경우 (right<start || end<left)

탐색을 이어나갈 필요가 없기 때문에 0을 리턴한다.

 

(2) [left,right]가 [start,end]를 완전히 포함하는 경우 (left<=start && end<=right)

이 경우도 탐색을 이어나갈 필요가 없다. 어차피 자식 노드는 [left, right] 구간 안에 포함되기 때문이다. tree[node]를 리턴하고 탐색을 종료한다.

 

(3) [start,end]가 [left,right]를 완전히 포함하는 경우

(4) [left,right]와 [start,end]가 겹쳐져 있는 경우 (1, 2, 3 제외한 나머지 경우)

3번과 4번의 경우 정확한 구간을 알기 위해 자식 노드를 루트로 재귀호출해 탐색을 한다.

 

 

그래도 코드로 구현하면 다음과 같다.

long long query(vector<long long> &tree, int node, int start, int end, int left, int right) {
    if (left > end || right < start) {
        return 0;
    }
    if (left <= start && end <= right) {
        return tree[node];
    }
    long long lsum = query(tree, node*2, start, (start+end)/2, left, right);
    long long rsum = query(tree, node*2+1, (start+end)/2+1, end, left, right);
    return lsum + rsum;
}

 

3. 원소 업데이트

원본 배열의 index 번째 수를 val로 변경하는 경우, index 번째를 포함하는 노드에 들어있는 합만 변경해주면 된다! 또는 index 번째 리프노드를 찾아가 부모노드의 값을 변경해주는 방법도 있다.

 

나는 두 번째 방법이 편한다. 취향차이!

void update(vector<long long> &a, vector<long long> &tree, int node, int start, int end, int index, long long val) {
    if (index < start || index > end) {
        return;
    }
    if (start == end) {
        a[index] = val;
        tree[node] = val;
        return;
    }
    update(a, tree,node*2, start, (start+end)/2, index, val);
    update(a, tree,node*2+1, (start+end)/2+1, end, index, val);
    tree[node] = tree[node*2] + tree[node*2+1];
}

 

이렇게 세그먼트 트리의 기본 연산(초기화, 탐색, 업데이트)를 구현하였다. 업데이트가 잦은!!!!! 배열의 누적합을 구할 때 사용하자!!!!!!!!!!!!

 

 

참고

https://book.acmicpc.net/ds/segment-tree