문제


https://www.acmicpc.net/problem/2365


풀이


아이디어를 떠올리는 것부터 매개 변수 탐색까지 전체적으로 난이도가 있었다.

우선 각 행과 열을 노드에 대응시켜, 행들은 sourse에 연결하고 열들은 sink에 연결한다.

이때 각 엣지의 capacity는 각 행, 열의 합이 된다.

그리고 행 노드들과 열 노드들을 잇는 간선들의 capacity를 매개 변수 탐색으로 찾아준다.

맨 처음에 그래프를 비효율적으로 만들어서 계속 시간초과가 떴다.

이후엔 매개 변수 탐색 구현 때문에 계속 틀리다가 겨우 맞췄다.

Ford-Fulkerson 알고리즘으로 풀었을 땐 pypy3 기준 9340ms가 나왔는데, Edmonds-Karp 알고리즘을 적용시키니 10분의 1 수준인 800ms까지 줄어들었다.


코드


import sys
from collections import deque

n = int(sys.stdin.readline())
row = list(map(int, sys.stdin.readline().split()))
col = list(map(int, sys.stdin.readline().split()))
row_sum = sum(row)

left = 0
right = 10000
while True:
    mid = (left + right) // 2
    
    graph = [[0] * (2 * n + 2) for _ in range(2 * n + 2)]

    for i in range(n):
        graph[0][i + 1] = row[i]
        graph[n + i + 1][2 * n + 1] = col[i]

    for i in range(1, n + 1):
        for j in range(n + 1, 2 * n + 1):
            graph[i][j] = mid

    total = 0
    while True:
        prev = [-1] * (2 * n + 2)
        queue = deque([(0, float('inf'))])
        while queue:
            cur, res = queue.popleft()

            for nxt in range(1, 2 * n + 2):
                if (nxt == 2 * n + 1) and graph[cur][-1]:
                    res = min(res, graph[cur][-1])
                    prev[-1] = cur
                    break

                if (prev[nxt] == -1) and graph[cur][nxt]:
                    prev[nxt] = cur
                    queue.append((nxt, min(res, graph[cur][nxt])))

            if prev[-1] != -1:
                break
        
        if prev[-1] == -1:
            break

        total += res
        while True:
            cur = prev[nxt]
            graph[cur][nxt] -= res
            graph[nxt][cur] += res
            nxt = cur
            if nxt == 0:
                break

    if total < row_sum:
        left = mid + 1
        if left > right:
            right = left
    else:
        right = mid - 1

        if left > right:
            print(mid)
            for j in range(1, n + 1):
                for i in range(n + 1, 2 * n + 1):
                    print(graph[i][j], end=' ')
                print()
            break

Leave a comment