Table of contents
Open Table of contents
들어가며
걸린 시간: 무한 (자력으로 풀지 못하고 솔루션을 보고 아이디어를 얻었습니다.)
사실 처음에 문제를 읽고, 오 이거 완전 탐색으로 풀면 되겠네 하고 접근을 했습니다.
오 웬걸 그런데 그렇게 접근을 하면 O(2^N * N^2) 이 나와서 이게 N = 20일 때, 4억정도가 나오게 되더라고요.
그래서 dp인가 하고 dp 점화식 세우다가 삽질을 하게 됐습니다.
접근
사실 제 접근이 맞았었습니다.
정확히는 O(2^N _ N^2)는 러프하게 나타낸 거고, 정확히 좀 더 시간복잡도를 계산하면 위의 식에서 N이 N - 1이 되고 상수로 나눠지고 대충 그래서, 4억이 아니라 대충 1억에 근접하게 됩니다.
그래서 pypy3로 돌리면 거의 시간초과에 걸릴 듯 말 듯 하게 돼서 AC를 받습니다.
이 문제에서 배운 점: 일단 계산했을 때 1억 근처면, 구현하고 생각해보자!! (최적화는 다음에 하면 되니)
근데 사실 이문제 O(2^N _ N)으로 최적화가 가능합니다.
위에 필기로 써놨는데, 매트릭스에서의 성질을 잘 이용하면 됩니다.
약간 수학 문제 같긴한데 (이게 실전에서 생각해 낼 수 있을지는 모르겠습니다 ㅎㅎ), 저희가 구해야 하는 게 L - S 이니깐,
(L + S + R) - (L + R + 2S) 임을 이용해서, 두 팀의 차를 O(N)만에 구할 수 있습니다.
구현
O(2^N * N^2)
import sys
from itertools import combinations
input = sys.stdin.readline
def calcDiff(start):
sum_start = sum_link = 0
start_combinations = combinations(start, 2)
link_combinations = combinations(numSet - set(start), 2)
for start_comb in start_combinations:
i, j = start_comb
sum_start += S[i][j] + S[j][i]
for link_comb in link_combinations:
i, j = link_comb
sum_link += S[i][j] + S[j][i]
return abs(sum_start - sum_link)
if __name__ == "__main__":
N = int(input().strip())
S = [[*map(int, input().strip().split())] for _ in range(N)]
numSet = set(range(N))
ans = int(1e9)
for i in range(1, N // 2 + 1):
combinationCandsGen = combinations(range(N), i)
for combinationCand in combinationCandsGen:
ans = min(ans, calcDiff(combinationCand))
print(ans)
O(2^N * N)
import sys
from itertools import combinations
input = sys.stdin.readline
def extractElements(bitmask):
ret = []
for i in range(N):
if bitmask % 2:
ret.append(i)
bitmask //= 2
return ret
def getNumElements(bitmask):
ret = 0
while bitmask:
if bitmask % 2:
ret += 1
bitmask //= 2
return ret
def calcDiff(start):
"""
O(N)
"""
tmp = 0
for e in extractElements(start):
tmp += sumRowList[e] + sumColList[e]
return abs(sumMatrix - tmp)
if __name__ == "__main__":
N = int(input().strip())
S = [[*map(int, input().strip().split())] for _ in range(N)]
numSet = 2 ** N - 1
sumRowList = [sum(S[i]) for i in range(N)]
sumColList = [sum(col) for col in zip(*S)]
sumMatrix = sum(sumRowList)
ans = int(1e9)
for i in range(1, 2 ** N):
numElements = getNumElements(i)
if numElements > (N // 2):
continue
ans = min(ans, calcDiff(i))
print(ans)