문제


https://www.acmicpc.net/problem/10563


풀이


사실 꽤나 애먹었던 문제다. 맞왜틀의 향연

처음엔 lower bound와 upper bound만 있으면 부분 문제가 결정될 줄 알았다.

어느 숫자 x를 뽑았을 때 다음 턴에 1을 제거할 수 있지 않다면, x를 기준으로 나뉜 두 구간 중 1이 없는 구간의 길이의 홀짝만 반영해 주면 될 줄 알았다.

이 방식이 n = 5일 때까지는 반례가 없다.

하지만 n = 6일 때부터 반례가 존재한다.

[3, 6, 5, 1, 4, 2]

위와 같은 방식이면 Alice가 6을 제거하고, Bob이 3을 제거하고, Alice가 5를 제거하고, Bob이 4를 제거하고, Alice가 1을 제거하면서 이긴다.

그러나 Alice가 6을 제거한 후 Bob이 4를 제거하면 상황이 달라진다.

5를 제거하는 사람이 패배하는데, 그 외에 제외할 수 있는 수가 3, 2 두 개이므로 Alice가 5를 제거하게 되면서 Bob이 승리한다.

그러므로 flag라는 boolean 변수를 추가해 줘야한다.

flag가 0이라는 것은 해당 구간 외에 제거할 선택지가 없다(= 짝수개)는 것이고, flag 1이라는 것은 해당 구간 외에 제거할 선택지가 1개(= 홀수개) 존재한다는 것이다.

이것이 왜 필요할까? 다음과 같은 예시를 보자.

[5, 1, 4, 2]

flag가 0일 땐, 먼저 시작한 사람이 5를 제거하면 이긴다.

flag가 1일 땐, 먼저 시작한 사람이 4를 제거하면 이긴다.

이렇게 같은 구간이라도 ‘구간 외 제거할 것의 유무’가 승패를 나눈다.


코드


import sys


def f(i, j, flag):
    if dp[i][j][flag] != -1:
        return dp[i][j][flag]
    
    res = 0
    for k in range(i, j + 1):
        if ((k == i) or (num[k - 1] < num[k])) and ((k == j) or (num[k] > num[k + 1])):
            if k == idx:
                dp[i][j][flag] = 1
                return 1
            
            if k > idx:
                res |= (1 - f(i, k - 1, (flag + j - k) % 2))
            else:
                res |= (1 - f(k + 1, j, (flag + k - i) % 2))
    
    if flag:
        res |= (1 - f(i, j, 0))
    
    dp[i][j][flag] = res
    return res


for _ in range(int(sys.stdin.readline())):
    n = int(sys.stdin.readline())
    num = list(map(int, sys.stdin.readline().split()))
    idx = num.index(1)
    dp = [[[-1, -1] for j in range(n)] for i in range(n)]

    print(['Bob', 'Alice'][f(0, n - 1, 0)])


반례 찾기


이렇게 AC 코드를 알고 있고, WA 코드의 반례를 찾고 싶을 때 random한 데이터 셋을 만들어서 두 코드의 결과가 다를 때까지 비교해 보곤 한다.

from numpy import random


# AC function
def f(i, j, flag):
    ...


# WA function
def g(i, j):
    ...


while True:
    n = random.randint(1, 10)
    num = list(map(lambda x: x + 1, random.permutation(n)))
    idx = num.index(1)

    dp1 = [[[-1, -1] for j in range(n)] for i in range(n)]
    ans1 = f(0, n - 1, 0)

    dp2 = [[-1] * n for _ in range(n)]
    ans2 = g(0, n - 1)

    if ans1 != ans2:
        print(num)
        break

Leave a comment