Skip to content

BOJ 백준 15661: 링크와 스타트

Updated: at 오전 12:10

Table of contents

Open Table of contents

들어가며

걸린 시간: 무한 (자력으로 풀지 못하고 솔루션을 보고 아이디어를 얻었습니다.)

사실 처음에 문제를 읽고, 오 이거 완전 탐색으로 풀면 되겠네 하고 접근을 했습니다.
오 웬걸 그런데 그렇게 접근을 하면 O(2^N * N^2) 이 나와서 이게 N = 20일 때, 4억정도가 나오게 되더라고요.
그래서 dp인가 하고 dp 점화식 세우다가 삽질을 하게 됐습니다.

접근

사실 제 접근이 맞았었습니다.
정확히는 O(2^N _ N^2)는 러프하게 나타낸 거고, 정확히 좀 더 시간복잡도를 계산하면 위의 식에서 N이 N - 1이 되고 상수로 나눠지고 대충 그래서, 4억이 아니라 대충 1억에 근접하게 됩니다.
그래서 pypy3로 돌리면 거의 시간초과에 걸릴 듯 말 듯 하게 돼서 AC를 받습니다.
이 문제에서 배운 점: 일단 계산했을 때 1억 근처면, 구현하고 생각해보자!! (최적화는 다음에 하면 되니) 근데 사실 이문제 O(2^N _ N)으로 최적화가 가능합니다.
위에 필기로 써놨는데, 매트릭스에서의 성질을 잘 이용하면 됩니다.
약간 수학 문제 같긴한데 (이게 실전에서 생각해 낼 수 있을지는 모르겠습니다 ㅎㅎ), 저희가 구해야 하는 게 L - S 이니깐, (L + S + R) - (L + R + 2S) 임을 이용해서, 두 팀의 차를 O(N)만에 구할 수 있습니다.

구현

O(2^N * N^2)

import sys
from itertools import combinations
input = sys.stdin.readline

def calcDiff(start):
    sum_start = sum_link = 0

    start_combinations = combinations(start, 2)
    link_combinations = combinations(numSet - set(start), 2)

    for start_comb in start_combinations:
        i, j = start_comb
        sum_start += S[i][j] + S[j][i]

    for link_comb in link_combinations:
        i, j = link_comb
        sum_link += S[i][j] + S[j][i]

    return abs(sum_start - sum_link)


if __name__ == "__main__":
    N = int(input().strip())
    S = [[*map(int, input().strip().split())] for _ in range(N)]
    numSet = set(range(N))
    ans = int(1e9)
    for i in range(1, N // 2 + 1):
        combinationCandsGen = combinations(range(N), i)
        for combinationCand in combinationCandsGen:
            ans = min(ans, calcDiff(combinationCand))

    print(ans)

O(2^N * N)

import sys
from itertools import combinations
input = sys.stdin.readline

def extractElements(bitmask):
    ret = []
    for i in range(N):
        if bitmask % 2:
            ret.append(i)
        bitmask //= 2

    return ret

def getNumElements(bitmask):
    ret = 0
    while bitmask:
        if bitmask % 2:
            ret += 1
        bitmask //= 2

    return ret


def calcDiff(start):
    """
    O(N)
    """
    tmp = 0
    for e in extractElements(start):
        tmp += sumRowList[e] + sumColList[e]
    return abs(sumMatrix - tmp)


if __name__ == "__main__":
    N = int(input().strip())
    S = [[*map(int, input().strip().split())] for _ in range(N)]
    numSet = 2 ** N - 1
    sumRowList = [sum(S[i]) for i in range(N)]
    sumColList = [sum(col) for col in zip(*S)]
    sumMatrix = sum(sumRowList)
    ans = int(1e9)
    for i in range(1, 2 ** N):
        numElements = getNumElements(i)
        if numElements > (N // 2):
            continue
        ans = min(ans, calcDiff(i))

    print(ans)