Notice
Recent Posts
Recent Comments
Link
«   2024/05   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
Archives
Today
Total
관리 메뉴

codingfarm

세그먼트 트리(Segment Tree) 본문

Algorithm & Data structure/이론

세그먼트 트리(Segment Tree)

scarecrow1992 2020. 9. 2. 22:44

세그먼트 트리는 배열의 구간합을 더 효율적으로 구하기 위한 알고리즘이다.

가령 임의의 배열이 주어질때

위 배열에서 길이 n만큼의 구간을 다 구하는데 걸리는 시간 복잡도는 $O(n)$ 이다.

하지만 매번 구간의 합을 구할때마다 해당길이 만큼 계속해서 반복하는것은 시간낭비이다.

더 효율적인 방법은 없을까?

 

임의의 배열 S가 주어질때 S[i]는 arr[0] ~ arr[i] 에 있는 모든 원소들의 합과 같다.

만약 $a \sim b$까지의 구간에 있는 원소들의 합을 구해야 한다면?

S[b]-S[a-1] 을 통해 구간의 합을 $O(1)$의 상수 시간만에 구할 수 있다.

하지만 이 방법은 배열 요소의 값이 변한다면 값을 수정하기가 굉장히 번거로워진다.

가령 a[i]의 값이 변한다면 S[i] 이후의 값들을 모두 수정시켜야한다.

그러므로 값을 수정하는데 걸리는 시간은 $O(n)$의 효율이 나온다.

 

그렇기에 구간합을 구하는데도, 값을 수정하는데도 최적의 성능이 나오는 알고리즘이 필요하다.

이것이 세그먼트 트리이다.

세그먼트 트리의 원리는 아래와 같다.

 

각 배열의 요소를 이진트리의 끝노드(leaf)로 삼은 후 트리 내의 각 노드는 왼쪽과 오른쪽 자식노드의 값을 더한것이라 가정한다.

그러면 구간 합을 구하는것은 아래처럼 할 수 있다.

 

배열의 길이가 6일때 세그먼트트리의 구조는 아래와 같다.

 

이진트리를 구현하는 방법은 연결리스트기반과 배열기반 2가지로 나뉘어 진다.

그러므로 세그먼트 트리도 마찬가지이다.

우선 연결리스트 기반부터 알아보자.

 

1. 연결리스트 기반

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include<iostream>
 
using namespace std;
 
class SegmentTree
{
public:
    SegmentTree *pLeft;
    SegmentTree *pRight;
 
    int from, to;
    int value;
 
    SegmentTree()
    {
        pLeft = 0;
        pRight = 0;
        value = -1;
        from = -1;
        to = -1;
    }
 
    void Init(int *arr, int left, int right)
    {
        from = left;
        to = right;
 
        if (left == right)
        {
            value = arr[from];
        }
        else
        {
            int mid = left + (right - left) / 2;
 
            pLeft = new SegmentTree;
            pRight = new SegmentTree;
 
            pLeft->Init(arr, left, mid);
            pRight->Init(arr, mid + 1, right);
 
            value = pLeft->value + pRight->value;
        }
    }
 
    int GetValue(int left, int right)
    {
        if (left <= from && to <= right)
            return value;
        else if (to < left || from > right)
            return 0;
        else
        {
            return pLeft->GetValue(left, right) + pRight->GetValue(left, right);
        }
    }
 
    void Clear()
    {
        if (pLeft != 0)
            if (pLeft->from != pLeft->to)
                pLeft->Clear();
 
        if (pRight != 0)
            if (pRight->from != pRight->to)
                pRight->Clear();
 
        delete pLeft;
        delete pRight;
    }
 
    //구간내의 배열에 특정값을 더할경우
    int PlusValue(int left, int right, int changer)
    {
        if (to < left || from > right)
            return 0;
        else if (from == to)
        {
            value += changer;
            return changer;
        }
        else
        {
            int a = pLeft->PlusValue(left,right,changer);
            int b = pRight->PlusValue(left, right, changer);
            value += a + b;
            return a + b;
        }
    }
 
    //index의 값을 value로 바꾼다
    int SetValue(int index, int number)
    {
        if (index < from || to < index)
        {    
            return value;
        }
        else if(from == index && to == index)
        {
            value = number;
            return value;
        }
        else
        {
            int a = pLeft->SetValue(index, number);
            int b = pRight->SetValue(index, number);
            value = a + b;
            return value;
        }
 
    }
 
 
};
 
 
 
int main(void)
{
    int arr[10= { 0,1,2,3,4,5,6,7,8,9 };
    SegmentTree trees;
    trees.Init(arr, 09);
 
    
    cout << trees.GetValue(09<< endl;
    cout << trees.GetValue(37<< endl;
 
    trees.PlusValue(372);
 
    for (int i = 0; i < 10; i++)
        cout << trees.GetValue(i, i) << " ";
 
    cout << endl;
 
    cout << trees.GetValue(37<< endl;
    cout << trees.GetValue(09<< endl;
    cout << trees.GetValue(24<< endl;
    
 
    trees.SetValue(430);
 
    for (int i = 0; i < 10; i++)
        cout << trees.GetValue(i, i) << " ";
    
    cout << endl;
    cout << trees.GetValue(37<< endl;
    cout << trees.GetValue(09<< endl;
    cout << trees.GetValue(24<< endl;
 
    trees.Clear();
}
cs

 

2. 배열기반

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include<iostream>
 
using namespace std;
 
class SegmentTree {
public:
    int* nodes;
    int* arr;
    int size;
 
    int SetTree(int index, int from, int to) {
        if (from == to) {
            nodes[index] = arr[from];
            return arr[from];
        }
        
        int mid = (from + to) / 2;
        int value = SetTree(index * 2 + 1, from, mid) + SetTree(index * 2 + 2, mid + 1, to);
    
        nodes[index] = value;
        return value;
    }
 
    int SetValue_(int index, int from, int to, int key, int value) {
        if (from == to) {
            int diff = value - nodes[index];
            nodes[index] = value;
            return diff;
        }
 
        int mid = (from + to) / 2;
        int diff;
        if (key <= mid) 
            diff = SetValue_(index * 2 + 1, from, mid, key, value);
        else 
            diff = SetValue_(index * 2 + 2, mid + 1, to, key, value);
        
        nodes[index] += diff;
        return diff;
    }
 
    //arr[key]를 value로 바꾸었을때 트리 내부 값을 수정
    void SetValue(int key, int value) {
        SetValue_(00size - 1, key, value);
    }
 
    int GetValue_(int index, int from, int to, int beginint end) {
        //범위를 벗어날 경우
        if (to < begin || from > end)
            return 0;
 
        //범위에 들어올 경우
        if (begin <= from && to <= end)
            return nodes[index];
 
        int mid = (from + to) / 2;
        int ret = GetValue_(index * 2 + 1, from, mid, beginend+ GetValue_(index * 2 + 2, mid + 1, to, beginend);
        return ret;
    }
 
    //begin에서 end사이의 합을 반환한다.
    int GetValue(int beginint end) {
        return GetValue_(00size - 1beginend);
    }
 
    void Init(int *arr_, int size_) {
        arr = arr_;
        size = size_;
        nodes = new int[size * 3];
        SetTree(00size - 1);
    }
 
    void Delete() {
        delete[] nodes;
    }
};
 
int main(void) {
    int arr[16= { 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 };
 
    SegmentTree segTree;
    segTree.Init(arr, 16);
 
    cout << segTree.GetValue(312<< "\n";
    cout << segTree.GetValue(413<< "\n";
    cout << segTree.GetValue(812<< "\n";
    cout << segTree.GetValue(015<< "\n";
    cout << segTree.GetValue(11<< "\n";
 
    segTree.SetValue(0100);
    segTree.SetValue(80);
        
    return 0;
}
cs

 

 

 

참고할점은 세그먼트트리는 값의 수정과 합에 있어서 최적의 성능을 발휘할뿐

단순히 구간합을 구하기만 하는거라면 앞서 했던것처럼 0부터 특정 지점 까지의 합을 배열에 저장하는 방식이 더 빠르다.

상황에 따라 적절한 알고리즘을 쓰도록 하자.

 

 

참고 : www.acmicpc.net/blog/view/9

'Algorithm & Data structure > 이론' 카테고리의 다른 글

Union Find  (0) 2020.09.06
Merge Sort  (0) 2020.09.06
list(static allocation)  (0) 2020.09.06
list(dynamic allocation)  (0) 2020.09.05
vector  (0) 2020.09.05
Comments