본문 바로가기
📊 Algorithm/Algorithm 주제별 정리

🧚‍♂️알고리즘🧚‍♂️ - 🌴 - 세그먼트 트리

by 정람지 2023. 7. 31.

🌴세그먼트 트리🌴

주어진 데이터의 구간 합과 데이터 업데이트를 빠르게 구현하기 위한 자료구조

(큰 범위 세그먼트 트리 == 인덱스 트리)

 

세그먼트 트리 종류 : 구간 합 / 최대 최소 구하기

 

🌴구현 단계🌴

1. 트리 초기화하기

트리 리스트의 크기 : 2 ^(k+1) ( 2^k >= N(노드수) 를 만족하는 k의 최솟값 )

 

2. 질의값 구하기 (구간 합 또는 최소/최대)

 

+ 원래 노드 인덱스를 세그먼트 트리 인덱스로 변경하기 :

세그먼트 트리 인덱스 = 주어진 질의 인덱스 + 리프 노드 시작 인덱스(2^k -1)

질의값 구하기

1) 시작인덱스 % 2 == 1일 때 해당 노드 선택
2) 끝인덱스 % 2 == 0일 때 해당 노드 선택
3) 시작인덱스 = (시작인덱스 + 1) / 2 (시작인덱스 깊이 변경)
4) 끝인덱스 = (끝인덱스 - 1) / 2 (끝인덱스 깊이 변경)
5) 1-4 반복, 끝인덱스 < 시작인덱스 시 종료
def getSeg(s,e):
    part_seg = 0 
    while(s <= e):
        if (s % 2 == 1):
            part_seg += tree[s] #구간 합
            part_seg = min(part_seg,tree[s]) #최소
            part_seg = max(part_seg,tree[s]) #최대
            s += 1
        if (e % 2 == 0):
            part_seg += tree[e] #구간 합
            part_seg = min(part_seg,tree[e]) #최소
            part_seg = max(part_seg,tree[e]) #최대
            e -= 1
        
        s //= 2
        e //= 2
    return part_seg

 

 

3. 데이터 업데이트하기

!자신의 부모 노드로 계속 이동하면서 업데이트

구간 합: 원래 데이터와 변경 데이터의 차이만큼 업데이트

# 값 변경 함수!
def changeVal(index, value):
    diff = value - tree[index] # 현재 노드의 값과 변경된 값의 차이
    while (index > 0):
        tree[index] = tree[index] + diff
        index = index // 2 # 현재 노드부터 최고 부모 노드까지 쭉 손보기

최댓/최솟값 : 변경 데이터와 다른 자식과의 비교 통해 업데이트


이제 클린 코드처럼 쓰도록 노력하고..

변수는 스네이크 함수는 캐멀 상수는 대문자 쓴다

 

🥇골드 1

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

구간합(합 배열)은 데이터의 변경이 빈번할 시 시간이 오래 걸림

세그먼트 트리 이용하기

 

pow(a,b) 함수 : a^b 계산 

 

세그먼트 트리는 1부터 시작하는 것이 계산에 편리 ( 인덱스 * 2는 왼쪽 자식 노드가 바로 되므로)

from sys import stdin
N,M,K= map(int, stdin.readline().split()) # 수 개수 , 변경 횟수, 구간합 구하기 횟수

tree_height = 0
length = N

#트리의 높이 구하기 (k 구하기)
while (length != 0):
    length //= 2 # 리프 노드의 개수를 2씩 나누어 가면서 높이 계산
    tree_height += 1

tree_size = pow(2,tree_height+1) # 트리리스트 크기 구하기
tree = [0] * (tree_size + 1) # 세그먼트 트리는 1부터 시작하는 것이 계산에 편리
left_node_start_index = tree_size // 2 -1 # 리프 노드 시작 인덱스

# 데이터 리프 노드 저장하기
for i in range(left_node_start_index + 1, left_node_start_index + N + 1):
    tree[i] = int(stdin.readline())

# 세그먼트 트리 생성 함수!
def setTree(i):
    while(i != 1):
        tree[i // 2] += tree[i] # 부모 노드에 현재 인덱스의 크기 더하기
        i -= 1
    
setTree(tree_size - 1) # 초기 세그먼트 트리 세팅하기

# 값 변경 함수!
def changeVal(index, value):
    diff = value - tree[index] # 현재 노드의 값과 변경된 값의 차이
    while (index > 0):
        tree[index] = tree[index] + diff
        index = index // 2 # 현재 노드부터 최고 부모 노드까지 쭉 손보기

# 구간 합 계산 함수!
def getSum(s,e):
    part_sum = 0 
    while(s <= e):
        if (s % 2 == 1): #시작인덱스 % 2 == 1일 때 해당 노드 선택
            part_sum += tree[s] # 해당 노드의 값을 구간 합에 추가
            s += 1
        if (e % 2 == 0): #끝인덱스 % 2 == 0일 때 해당 노드 선택
            part_sum += tree[e]
            e -= 1
        
        s //= 2 #시작인덱스 = (시작인덱스 + 1) / 2 (시작인덱스 깊이 변경)
        e //= 2 #끝인덱스 = (끝인덱스 - 1) / 2 (끝인덱스 깊이 변경)
    return part_sum

# 결과
for _ in range(M + K):
    ques,s,e = map(int, stdin.readline().split())

    if (ques == 1): # 값 변경
        s += left_node_start_index
        changeVal(s,e)
    else: # 구간합 계산
        s += left_node_start_index #트리에서 원래 데이터 노드들은 리프에 위치해 있어 left~ 값을 더해주어야 함 주의
        e += left_node_start_index
        print(getSum(s,e))

🥇골드 1

 

10868번: 최솟값

N(1 ≤ N ≤ 100,000)개의 정수들이 있을 때, a번째 정수부터 b번째 정수까지 중에서 제일 작은 정수를 찾는 것은 어려운 일이 아니다. 하지만 이와 같은 a, b의 쌍이 M(1 ≤ M ≤ 100,000)개 주어졌을 때는

www.acmicpc.net

입력값들 크기 받기-> '2^k >=N' 만족 k값 찾기-> 트리 리스트 크기 구하기 / (리프 노드)시작 인덱스 값 구하기'

이번엔 최솟값!

from sys import stdin
N,M = map(int, stdin.readline().split()) # 수 개수 , 최솟값 구하기 횟수

tree_height = 0
length = N

#트리의 높이 구하기 (k 구하기)
while (length != 0):
    length //= 2 # 리프 노드의 개수를 2씩 나누어 가면서 높이 계산
    tree_height += 1

tree_size = pow(2,tree_height+1) # 트리리스트 크기 구하기
tree = [2**31-1] * (tree_size + 1) # 세그먼트 트리는 1부터 시작하는 것이 계산에 편리
left_node_start_index = tree_size // 2 -1 # 리프 노드 시작 인덱스

# 데이터 리프 노드 저장하기
for i in range(left_node_start_index + 1, left_node_start_index + N + 1):
    tree[i] = int(stdin.readline())

# 세그먼트 트리 생성 함수!
def setTree(i):
    while(i != 1):
        # 부모 노드, 현재 인덱스 중 최솟값 넣기
        tree[i // 2] = min(tree[i // 2], tree[i])
        i -= 1
    
setTree(tree_size - 1) # 초기 세그먼트 트리 세팅하기

# 최솟값 구하기 함수!
def getMin(s,e):
    part_Min = 2**31-1
    while(s <= e):
        if (s % 2 == 1): #시작인덱스 % 2 == 1일 때 해당 노드 선택
            part_Min = min(part_Min, tree[s]) # 해당 노드의 값을 구간 합에 추가
            s += 1
        if (e % 2 == 0): #끝인덱스 % 2 == 0일 때 해당 노드 선택
            part_Min = min(part_Min, tree[e])
            e -= 1
        
        s //= 2 #시작인덱스 = (시작인덱스 + 1) / 2 (시작인덱스 깊이 변경)
        e //= 2 #끝인덱스 = (끝인덱스 - 1) / 2 (끝인덱스 깊이 변경)

    return part_Min

# 결과
for _ in range(M):
    s,e = map(int, stdin.readline().split())
    print(getMin(s+left_node_start_index,e+left_node_start_index))#left_node_start_index주의!!!

🥇골드 1

 

11505번: 구간 곱 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 곱을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

구간합에서 변경된 구간곱

값 갱신도 있다

from sys import stdin
N,M,K= map(int, stdin.readline().split()) # 수 개수 , 변경 횟수, 구간합 구하기 횟수

tree_height = 0
length = N

#트리의 높이 구하기 (k 구하기)
while (length != 0):
    length //= 2 # 리프 노드의 개수를 2씩 나누어 가면서 높이 계산
    tree_height += 1

tree_size = pow(2,tree_height+1) # 트리리스트 크기 구하기
tree = [1] * (tree_size + 1) # 세그먼트 트리는 1부터 시작하는 것이 계산에 편리
left_node_start_index = tree_size // 2 -1 # 리프 노드 시작 인덱스

# 데이터 리프 노드 저장하기
for i in range(left_node_start_index + 1, left_node_start_index + N + 1):
    tree[i] = int(stdin.readline())

# 세그먼트 트리 생성 함수!
def setTree(i):
    while(i != 1):
        tree[i // 2] *= tree[i] # 부모 노드에 현재 인덱스의 크기 곱하기
        i -= 1
    
setTree(tree_size - 1) # 초기 세그먼트 트리 세팅하기

# 값 변경 함수!
def changeVal(index, value):
    diff = value/tree[index] # 현재 노드의 값과 변경된 값의 차이
    while (index > 0):
        tree[index] = tree[index] * diff
        index = index // 2 # 현재 노드부터 최고 부모 노드까지 쭉 손보기

# 구간곱 계산 함수!
def getMul(s,e):
    part_mul = 1
    while(s <= e):
        if (s % 2 == 1): #시작인덱스 % 2 == 1일 때 해당 노드 선택
            part_mul *= tree[s] # 해당 노드의 값을 구간 합에 추가
            s += 1
        if (e % 2 == 0): #끝인덱스 % 2 == 0일 때 해당 노드 선택
            part_mul *= tree[e]
            e -= 1
        
        s //= 2 #시작인덱스 = (시작인덱스 + 1) / 2 (시작인덱스 깊이 변경)
        e //= 2 #끝인덱스 = (끝인덱스 - 1) / 2 (끝인덱스 깊이 변경)

    return part_mul

# 결과
for _ in range(M + K):
    ques,s,e = map(int, stdin.readline().split())

    if (ques == 1): # 값 변경
        s += left_node_start_index
        changeVal(s,e) #여기서 e는 val
    else: # 구간합 계산
        s += left_node_start_index #트리에서 원래 데이터 노드들은 리프에 위치해 있어 left~ 값을 더해주어야 함 주의
        e += left_node_start_index
        print(int(getMul(s,e)))

아 좀만 바꾸면 되네~ 껌이다~

pypy로 햇더니.. 

0이 있으면 어쩌지 음

 

문제 하나 더 있었음..

 

문제 풀이에서 1000000007, 1000000009으로 나누는 이유가 뭐지? (feat.모듈로 연산)

문제 풀이에서 1000000007, 1000000009으로 나누는 이유가 뭐지? (feat.모듈로 연산) 가끔 알고리즘 풀다보면 뭐 숫자를 몇으로 나눈 나머지를 써라~~ 이런 것들이 보인다. 아니 이거 왜씀? 하기도 하고,

hello-backend.tistory.com

 

❗️오버플로우 방지를 위해 지속적인 mod연산 필요

❗️0값 오류 방지를 위해 갱신시 다시 처음부터 계산하도록!!

from sys import stdin
N,M,K= map(int, stdin.readline().split()) # 수 개수 , 변경 횟수, 구간합 구하기 횟수

MOD = 1000000007 # 모듈러연산!!
tree_height = 0
length = N

#트리의 높이 구하기 (k 구하기)
while (length != 0):
    length //= 2 # 리프 노드의 개수를 2씩 나누어 가면서 높이 계산
    tree_height += 1

tree_size = pow(2,tree_height+1) # 트리리스트 크기 구하기
tree = [1] * (tree_size + 1) # 세그먼트 트리는 1부터 시작하는 것이 계산에 편리
left_node_start_index = tree_size // 2 -1 # 리프 노드 시작 인덱스

# 데이터 리프 노드 저장하기
for i in range(left_node_start_index + 1, left_node_start_index + N + 1):
    tree[i] = int(stdin.readline())

# 세그먼트 트리 생성 함수!
def setTree(i):
    while(i != 1):
        tree[i // 2] = tree[i // 2] * tree[i] % MOD # 부모 노드에 현재 인덱스의 크기 곱하기 // 모듈러연산!!오버플로우 방지
        i -= 1
    
setTree(tree_size - 1) # 초기 세그먼트 트리 세팅하기

# 값 변경 함수! 
# 0 주의 ! 아래부터 다시 전부 계산하기
def changeVal(index, value):
    tree[index] = value
    while (index > 1):
        index = index // 2 # 변경 리프 노드부터 최고 부모 노드까지 쭉 손보기
        tree[index] = (tree[index*2]%MOD) * (tree[index*2 + 1]%MOD) # 두 자식 노드 계산 / MOD 연산 주의

# 구간곱 계산 함수!
def getMul(s,e):
    part_mul = 1
    while(s <= e):
        if (s % 2 == 1): #시작인덱스 % 2 == 1일 때 해당 노드 선택
            part_mul = part_mul * tree[s] % MOD # 모듈러연산!
            s += 1
        if (e % 2 == 0): #끝인덱스 % 2 == 0일 때 해당 노드 선택
            part_mul = part_mul * tree[e] % MOD # 모듈러연산!
            e -= 1
        
        s //= 2 #시작인덱스 = (시작인덱스 + 1) / 2 (시작인덱스 깊이 변경)
        e //= 2 #끝인덱스 = (끝인덱스 - 1) / 2 (끝인덱스 깊이 변경)

    return part_mul

# 결과
for _ in range(M + K):
    ques,s,e = map(int, stdin.readline().split())

    if (ques == 1): # 값 변경
        s += left_node_start_index
        changeVal(s,e) #여기서 e는 val
    else: # 구간합 계산
        s += left_node_start_index #트리에서 원래 데이터 노드들은 리프에 위치해 있어 left~ 값을 더해주어야 함 주의
        e += left_node_start_index
        print(getMul(s,e) % MOD) # mod연산 주의

ㅜㅠ.. 혼자 푸는 능력 부족