Segment Tree
Updated:
Segment Tree
세그먼트 트리는 보통 업데이트가 있는 구간 쿼리 문제에서 사용되는 알고리즘이다. 구간 쿼리 문제란 말 그대로 배열에서 특정 구간에 대해 질문하는 문제이다. 구간 쿼리 문제의 예시는 다음과 같다.
배열 a가 주어졌을 때, 다음과 같은 구간 쿼리를 실행하여라.
1 i j: i번째 원소 ~ j번째 원소의 합을 출력한다.
2 i v: i번째 원소의 값을 v로 변경한다.
만약 단순 배열로 위 문제를 해결한다고 하면 시간복잡도는 $O(N)$이 나온다. 하지만 세그먼트 트리를 이용한다면 쿼리의 수가 $Q$라 했을 때, $O((N+Q)logN)$에 해결이 가능하다.
세그먼트 트리의 아이디어는 다음과 같다. 구간 쿼리 문제를 빠르게 해결하기 위해 모든 구간의 계산값을 미리 저장하는 대신 특정 구간만 미리 계산해두고 구간 쿼리가 들어왔을 때, 해당 구간에 맞게 이전에 계산 해둔 구간을 조합하는 것이다.
세그먼트 트리는 위와 같이 배열을 절반씩 쪼개가면서 계산한 구간합을 트리의 형태로 저장한다. 만약 1 4 7이라는 쿼리가 들어온다면, 7 + sum(5,6) + 4 를 계산하면 된다.
2번 쿼리의 경우 2 3 6의 쿼리가 들어온다면, 배열의 3번째 값의 노드와 그 상위 노드들만 업데이트 해주면 된다. 이러한 방식으로 두 가지 쿼리를 모두 log시간에 해결할 수 있다.
세그먼트 트리의 구현을 살펴보면, 우선 최초의 배열을 통해 세그먼트 트리를 초기화 하는 함수는 다음과 같다.
//arr은 입력받을 배열, tree는 구간합을 처리할 배열로 사이즈는 4*n이면 충분하다.
int tree[400001], arr[100001];
//node는 트리에서의 현재 노드의 인덱스. s,e는 arr에서 노드가 담당하는 부분의 좌측, 우측 끝 인덱스.
int init(int node, int s, int e){
if(s==e) return tree[node]=arr[s]; //leaf node
int m=s+e>>1;
return tree[node]=init(node*2,s,m)+init(node*2+1,m+1,e);
}
초기화 시에는 init(1,1,N)
으로 호출한다. 1번 쿼리의 구현은 다음과 같다.
int query(int node, int s, int e, int i, int j){
if(s>j || e<i) return 0; //탐색 범위를 벗어나면 0리턴
if(s>=i && e<=j) return tree[node]; //탐색 범위에 포함되면 해당 노드의 값 리턴
int m=s+e>>1;
return query(node*2,s,m,i,j)+query(node*2+1,m+1,e,i,j);
}
2번 업데이트 쿼리의 경우 다음과 같다.
long update(int node, int s, int e, int i, long v){
if(i<s || i>e) return tree[node]; //탐색하는 범위에 포함되지 않는다면, 해당 노드의 값 리턴
if(s==e) return tree[node]=v; //리프 노드라면 update 후 리턴
int m=s+e>>1;
return tree[node]=update(node*2,s,m,i,v)+update(node*2+1,m+1,e,i,v);
}
Merge Sort Tree
Merge Sort
머지 소트는 $O(NlogN)$의 시간복잡도를 가지는 정렬 방식의 일종으로 빠른 속도와 분할 정복을 사용한다.
정렬이 되지 않은 배열을 합치는 것은 복잡하지만, 정렬이 된 배열 두 개를 합쳐서 정렬된 상태로 만드는 것은 $O(N)$의 시간이면 충분하다. 정렬 방법은 다음과 같다.
int arr1[]={1,4,6,7}, arr2[]={2,3,4,8}, arr3[8];
for(int i=0,j=0,k=0; k<8; k++){
if(j==4 || (i<4 && arr1[i]<arr2[j])) arr3[k]=arr1[i++];
else arr3[k]=arr2[j++];
}
두 개의 배열에 각각 포인터를 지정하여 포인터가 가리키는 수 중에 더 작은 것을 취하고 해당 배열의 포인터를 1증가시키는 행위를 반복하면 된다.
분할 정복이란 하나의 문제를 여러개로 분할하여 해결한 후 합치는 방식으로 문제를 해결하는 것을 말한다.
위 알고리즘에 분할 정복을 적용하면 배열의 사이즈가 1일 때는 정렬이 이미 된 상태이고, 위 알고리즘을 통해 사이즈가 2인 정렬된 배열을 만들고, 다시 사이즈가 2인 배열을 합쳐서 사이즈가 4인 정렬된 배열을 만드는 식으로 반복하면 최종적으로 배열의 전체가 정렬된다.
최종적으로 머지 소트를 구현하면 다음과 같다.
void merge(int s, int e){
if(s==e) return;
int m=s+e>>1;
merge(s, m);
merge(m+1, e);
vector<int> tmp;
for(int i=s,j=m+1,k=0;k<=e-s;k++){
if(j>e || (i<=m && arr[i]<arr[j])) tmp.push_back(arr[i++]);
else tmp.push_back(arr[j++]);
}
for(int i=0; i<=e-s; i++) arr[s+i]=tmp[i];
return;
}
Merge Sort Tree
이제 머지 소트 트리에 대해 알아보자. 머지 소트 트리는 주로 업데이트가 없는 쿼리 문제에 사용되는 알고리즘으로 보통 구간에서 특정한 값 이상(이하, 초과 등)의 수가 몇 개인지 구하는 쿼리를 해결할 때 사용되고, 쿼리당 $O(logN)$의 시간에 해결이 가능하다.
머지 소트 트리의 형태는 앞서 살펴본 머지 소트 과정 그림 그대로 이다. 즉, 머지 소트 중 정렬된 중간 과정의 배열을 트리에 저장하기만 하면 된다. 세그먼트 트리와 동일하게 노드 자신의 왼쪽 자식 노드는 인덱스 * 2, 오른쪽 자식 노드는 인덱스 * 2 + 1을 한다.
머지 소트 트리 문제의 쿼리 예시는 다음과 같다.
i j v: i번째 원소 ~ j번째 원소 중 v이상인 수의 개수를 출력해라.
만약 3 6 4라는 쿼리가 들어왔다고 하자. 머지 소트의 노드들이 정렬되어 있다는 점을 이용해 이분 탐색을 활용할 수 있다. 이분 탐색을 통해 lower bound(4 이상인 수 중에 가장 작은 것의 위치)를 찾으면 log시간에 4이상의 수가 몇 개인지 알 수 있다.
머지 소트 트리 구현에 앞서 merge와 lower_bound함수를 살펴보면 다음과 같다.
merge(s1,e1,s2,e2,s3)
s1,s2: 합칠 배열 1,2의 시작 포인터
e1,e2: 합칠 배열 1,2의 끝 포인터
lower_bound(s,e,v)
s: 이분 탐색할 배열의 시작 포인터
e: 이분 탐색할 배열의 끝 포인터
v: 이분 탐색에 사용할 비교값
lower_bound는 v이상인 수 중에 가장 작은 것의 포인터를 반환한다. 당연히 사용할 배열은 정렬되어 있어야 한다. 문제 조건에 따라 lower_bound를 쓰기도 하고 upper_bound를 쓰기도 한다.
머지 소트 트리의 초기화는 다음과 같이 구현된다.
void init(int node, int s, long e){
if(s==e){
tree[node].push_back(arr[s]);
return;
}
int m=s+e>>1;
init(node*2,s,m);
init(node*2+1,m+1,e);
tree[node].resize(tree[node*2].size()+tree[node*2+1].size());
merge(tree[node*2].begin(), tree[node*2].end(),
tree[node*2+1].begin(), tree[node*2+1].end(), tree[node].begin()
);
return;
}
다음으로 쿼리는 아래와 같다.
int query(int node, int s, int e, int i, int j, long v){
if(s>j || e<i) return 0;
if(s>=i && e<=j) return tree[node].end()-upper_bound(tree[node].begin(), tree[node].end(), v);
int m=s+e>>1;
return query(node*2,s,m,i,j,v)+query(node*2+1,m+1,e,i,j,v);
}
Persistant Segment Tree
Persistant Segment Tree는 세그먼트 트리에 업데이트가 일어나더라도 업데이트 이전의 트리 상태를 전부 기록하는 세그먼트 트리이다.
다만 세그먼트 트리 전체를 업데이트 할 때마다 저장하면 용량이 매우 커지므로 실제로 업데이트 되는 노드에 한해서만 기록한다.
위 처럼 3번째 원소를 업데이트 한다면 총 4개의 노드만 값이 변경되므로 해당 노드들만 추가하고 나머지 점선으로 된 노드는 재활용하면 된다.
이를 위해 이전의 세그먼트 트리 방식과 다르게 포인터를 이용해 이전에 사용한 노드를 재활용하고 업데이트한 노드만 새로 만들어 사용한다. 이러한 방식을 동적 세그먼트 트리라고 한다.
Persistant Segment Tree의 쿼리 예시는 다음과 같다.
1 i j c: c번째 구간 쿼리까지 진행된 상태에서 i번째 원소 ~ j번째 원소의 합을 출력한다.
2 i v: i번째 원소를 v로 변경한다.
노드로 사용할 구조체는 다음과 같다.
struct Node{
int sum;
Node *l, *r, *shadow; //왼쪽, 오른쪽 자식 노드 포인터, 이전 쿼리에서의 노드 자신의 포인터
};
초기화는 일반 세그먼트 트리와 동일하다.
Node* tree[100001]; //tree[c]는 c번째 쿼리가 진행된 트리 루트의 포인터
int arr[100001];
int init(Node* node, int s, int e){
if(s==e) return node->sum=arr[s];
int m=s+e>>1;
node->l=new Node; //왼쪽 자식노드 생성
node->r=new Node; //오른쪽 자식노드 생성
return node->sum=init(node->l,s,m)+init(node->r,m+1,e);
}
업데이트는 다음과 같다. 업데이트 시 값이 갱신되는 노드만 새로 생성되어 해당 노드의 이전 쿼리 노드 즉, shadow와 연결되고 값이 갱신되지 않는 노드는 새로 생성할 필요 없이 이전 쿼리 노드인 shadow에 바로 연결해준다.
int update(Node* node, int s, int e, int i, int v){
if(s==e) return node->sum=v; //리프 노드일 경우 값을 업데이트하고 리턴
int m=s+e>>1;
if(i>m){ //업데이트 할 노드가 오른쪽인 경우
node->r=new Node; //오른쪽 자식 노드 생성
node->r->shadow=node->shadow->r; //오른쪽 자식 노드의 과거 노드 포인터 초기화
node->l=node->shadow->l; //왼쪽 자식 노드는 과거 노드의 왼쪽 노드에 연결
return node->sum=node->l->sum+update(node->r,m+1,e,i,v);
}
else{
node->l=new Node;
node->l->shadow=node->shadow->l;
node->r=node->shadow->r;
return node->sum=node->r->sum+update(node->l,s,m,i,v);
}
}
마지막으로 합 쿼리이다. 쿼리는 세그먼트 트리와 동일하다.
int query(Node* node, int s, int e, int i, int j){
if(i>e || j<s) return 0;
if(s>=i && e<=j) return node->sum;
int m=s+e>>1;
return query(node->l,s,m,i,j)+query(node->r,m+1,e,i,j);
}
문제 해결을 위한 메인 함수는 다음과 같다.
int main()
{
int N, Q, t1, t2, t3, t4, U=0;
cin >> N;
for(int i=1; i<=N; i++) cin >> arr[i];
cin >> Q;
tree[0]=new Node;
init(tree[0],1,N);
for(int i=1; i<=Q; i++){
cin >> t1 >> t2 >>t3;
if(t1==1){
cin >> t4;
cout << query(tree[t4],1,N,t2,t3) << endl;
}
else{
U++;
tree[U]=new Node;
tree[U]->shadow=tree[U-1];
update(tree[U],1,N,t2,t3);
}
}
return 0;
}
업데이트를 할 때마다 새로운 노드를 만들고 이전 노드와 shadow를 통해 연결하는 것을 볼 수 있다. 이후는 앞서 살펴보았듯이 갱신 되는 노드일 경우만 새로운 노드를 만들어 shadow끼리 연결하고 아니면 shadow의 자식 노드를 그대로 가져온다.
Lazy Propagation Segment Tree
Lazy Propagation Segment Tree의 쿼리는 다음과 같다.
1 i j: i ~ j번째 원소의 합을 출력한다.
2 i j v: i ~ j번째 원소에 v를 더한다.
세그먼트 트리와 비슷하지만 업데이트를 구간으로 진행한다는 점이 다르다. 이 경우 세그먼트 트리의 업데이트를 여러번 사용하면 쿼리당 $O(N \log N)$의 시간이 걸릴 수 있다.
Lazy Propagation Segment Tree의 아이디어는 세그먼트 트리에서 구간합을 구할 때 여러 구간들을 합쳐서 계산하는 것처럼 업데이트도 구간 전체를 담당하는 노드에만 해두고 필요할 때 즉, 하위 노드를 방문할 때 업데이트를 하는 것이다.
이를 위해 각 노드들에 lazy라는 값을 추가해준다.
위 트리에서 2 1 4 1라는 쿼리가 들어왔다고 하자.
위와 같이 sum(1, 8)
, sum(1, 4)
두 노드에만 업데이트를 진행해준다. 구간 전체에 1을 더하는 쿼리이기 때문에 구간의 길이만 안다면 굳이 하위 노드에 가지 않아도 값을 알 수 있다. 쿼리 구간이 1 ~ 4이므로 이 구간을 담당하는 노드의 하위 노드에는 업데이트 대신 lazy값을 추가해준다.
이제 1 4 5라는 쿼리가 들어왔다고 하자.
트리의 우측 부분의 경우 lazy가 없으므로 원래대로 5번째 인덱스를 담당하는 노드에 접근하면 된다. 하지만 4번째 인덱스를 담당하는 노드의 경우 상위 노드 sum(3,4)
에 lazy가 존재하므로 해당 노드에서 4번째 인덱스를 담당하는 노드로 바로 내려오는게 아니라 lazy를 하위 노드에 전파(propagation)하고 lazy 값과 범위에 맞게 해당 노드의 값도 수정해준 뒤 내려온다.
마지막으로 4번째 인덱스를 담당하는 노드로 내려오면 lazy가 존재하는데 해당 노드는 하위 노드가 없으므로 lazy값만 더해주고 리턴한다.
이렇게 하위 노드의 값들을 느리게 갱신해주는 세그먼트 트리를 Lazy Propagation Segment Tree라고 한다.
Lazy Propagation Segment Tree의 초기화 구현은 일반 세그먼트 트리와 동일하므로 생략한다.
lazy값을 하위 노드로 propagation하는 함수와 update는 다음과 같다.
long lazy[4000001];
void prop(int node, int s, int e){
tree[node]+=lazy[node]*(e-s+1); //해당 노드의 값에 lazy값을 업데이트
if(s!=e) for(int i : {node*2, node*2+1}) lazy[i]+=lazy[node]; //lazy값을 자식 노드에 propagation
lazy[node]=0;
return;
}
long update(int node, int s, int e, int i, int j, long v){
if(lazy[node]) prop(node,s,e); //lazy값이 존재한다면 propagation
if(i>e || j<s) return tree[node]; //탐색범위에 포함되지 않는다면 노드값 리턴
if(i<=s && e<=j){ //탐색하는 범위에 완전히 포함되는 경우
lazy[node]=v; //lazy값 갱신
prop(node,s,e); //lazy값 propagation
return tree[node]; //node값 리턴
}
int m=s+e>>1;
return tree[node]=update(node*2,s,m,i,j,v)+update(node*2+1,m+1,e,i,j,v);
}
update를 보면 탐색범위에 완전히 포함되는 경우에만 lazy값을 갱신하는 것을 볼 수 있다. 따라서 prop함수가 호출될 때는 좌우 노드가 모두 propagation 대상이 된다.
쿼리의 경우 세그먼트 트리와 거의 동일하다.
long query(int node, int s, int e, int i, int j){
if(lazy[node]) prop(node,s,e); //lazy값이 존재한다면 propagation
if(i>e || j<s) return 0; //탐색하는 범위에 포함되지 않는다면 0리턴
if(i<=s && e<=j) return tree[node]; //탐색하는 범위에 완전히 포함된다면 노드값 리턴
int m=s+e>>1;
return query(node*2,s,m,i,j)+query(node*2+1,m+1,e,i,j);
}
Reference
공군 휴머니스트 air-wiki
Leave a comment