読者です 読者をやめる 読者になる 読者になる

T2M_MJM

大学などで学んだことを真面目に書いていきます。

AtCoder Beginner Contest 046 D 問題

AtCoder

CTF と並行して、 AtCoder の問題によるアルゴリズムの勉強も習慣にしておこうと思っているので、先日解いた AtCoder Beginner Contest 046 D 問題について簡単な解答をまとめておきます。

abc046.contest.atcoder.jp

問題概要

グーとパーだけで行われ、かつグーを出した回数だけパーを出せるようなじゃんけんについての問題です。このじゃんけんで、相手の出す手が全てわかっている状態で、勝敗によって点数が1点ずつ上下する n 回勝負を行った場合に取りうる得点の最大値を求めます。

解答

答えから言ってしまうと、グーとパーを交互に出すことで得点が最大化されます。このじゃんけんは要するに「相手より何回パーを多く出したか」というのがそのまま得点数になり、したがって相手の出す手が何であれとにかくパーを多く出せば良いからです。

したがって、解答は以下のようなコードになります。

T = input()
A = ["gp"[i%2] for i in range(len(T))]
A = "".join(A)

point = 0
for t,a in zip(T,A):
    if t == "g" and a == "p":
        point += 1
    elif t == "p" and a == "g":
        point -= 1

print(point)

実際に解答に至るまでの諸々

これだけで話を終わらせてしまうのも味気ない(上に記録として全く意味を為さない)ので、実際に解答に至るまでにどのようなことをしたのかについて述べて記事を終わらせようと思います。

流れとしては、

  1. テストセットを生成するスクリプトを書く
  2. テストセットについて、(自分の手元で)得点数・それまでに出したグーとパーの数などの遷移を書き出す
  3. 最終的にパーを使い切っているものが得点を最大化する(ように見える)ことに気がつく

という手順で問題を解きました。数学的センスや知識があるともう少しきれいな手順で解けるかとも思います。勉強を積んだときにこの記事を読み返して自分で振り返ることが出来ると幸いです。

以下はテストセットを生成するスクリプトです。

import random
import itertools
import functools

def valid(testset):
    results = []
    for i in range(len(testset)+1):
        num_g = len(list(filter(lambda c: c is "g", testset[:i])))
        num_p = len(list(filter(lambda c: c is "p", testset[:i])))
        results.append(num_p <= num_g)
    return functools.reduce((lambda a,b: a & b), results)

def solve(testset):
    answers = []
    for answer in itertools.product("pg", repeat=len(testset)):
        if not valid(answer):
            continue
        point = 0
        for (T, A) in zip(testset, answer):
            if T == "g" and A == "p":
                point += 1
            elif T == "p" and A == "g":
                point -= 1
        answers.append((point, "".join(answer)))
    return max(answers)



testsets = []
for length in range(3,10):
    for i in range(100):
        testset = [("g","p")[random.randint(0,1)] for _ in range(length)]
        testset = "".join(testset)
        if testset not in testsets:
            testsets.append(testset)
testsets = filter(valid,testsets)

for t in testsets:
    print(t, solve(t))