AtCoder Beginner Contest 297 Ex - Diff Adjacent

最後の一手で詰まっていた……。

問題へのリンク

問題

整数 $N$ が与えられる。

  • 要素の総和が $N$
  • 隣り合う項の値が異なる

を満たすような数列の長さの総和を $998244353$ で割った余りを求めよ。

前提

とくになし

解法

問題の第一印象は

  • 「$N-1$ 個の条件を全て満たす数列を数え上げる」問題なので、包除が使えそう
  • 長さを足し上げろと言われているので、数列の長さを変数の肩に乗せた形式的冪級数を考えて、最後に微分すればよさそう

だったので、考察の方針は FPS にロックオンした。

項が隣接している部分は $N - 1$ 箇所あるので、条件に違反しているもの(つまり、等しくなっているもの)の決め打ち方 $2 ^ {N - 1}$ 通りに対して包除原理を適用する。

図に示したとおり、条件に違反する隣接部分を決め打つことは、列を等しい値の区間で分割することと同じだと思える。

よって、長さ $ M $ の列の個数は

$$ [x ^ N] \sum _ { K = 1 } ^ \infty \sum _ {c _ 1 + \dots + c _ K = M} \prod _ {i = 1} ^ K (-1) ^ {c _ i - 1} (x ^ {c_i} + x ^ {2c_i} + \dots ) $$

と書ける。$ c _ 1 + \dots + c _ K = M $ の部分を表現するために変数 $y$ を導入すると、上式は

$$ [x ^ N y ^ M] \sum _ { K = 1 } ^ \infty \Brace{ \sum _ {c = 1} ^ \infty \frac{-(-xy) ^ c}{1 - x ^ c} } ^ K = [x ^ N y ^ M] \frac{1}{1 + \sum _ {c = 1} ^ \infty \frac{(-xy) ^ c}{1 - x ^ c}} $$

と整理できる。

求めたいものは長さの総和だったので、$ y $ の指数を微分で下ろしてくればよい。また、 $ M $ を消すために全体を $1 - y$ で割って $y ^ N$ の係数を取ることにすると、求める値は

$$ [x ^ N y ^ N] \frac{y}{1 - y} \cdot \frac{\partial}{\partial y} \left( \frac{1}{1 + \sum _ {c = 1} ^ \infty \frac{(-xy) ^ c}{1 - x ^ c}} \right) $$

となる。めでたしめでたし……と思いきや、これを計算する方法が分からなくて困った……(ここでコンテスト終了)。

実は、$ 1 - y $ で割る代わりに、$ y $ に $ 1 $ を代入すると係数の和を取ることができる!

よって、

$$ [x ^ N] \left. \frac{\partial}{\partial y} \left( \frac{1}{1 + \sum _ {c = 1} ^ \infty \frac{(-xy) ^ c}{1 - x ^ c}} \right) \right| _ {y = 1} = [x ^ N] \frac{-\sum _ {c = 1} ^N \frac{c(-x) ^ c}{1 - x ^ c} }{\left(1 + \sum _ {c = 1} ^ N \frac{(-x) ^ c}{1 - x ^ c}\right) ^ 2} $$

が答えとなる。

$\frac{(-x) ^ c}{1 - x ^ c}$ のうち非零の係数を持つ項は $\Floor{N / c}$ 個しかないので、$ \sum _ {c = 1} ^ N \frac{(-x) ^ c}{1 - x ^ c} $ は $ O(N \log N) $ 時間で構築できる(調和数のやつ)。

ここまでくれば、あとは FPS ライブラリに投げるだけ。

コード

https://atcoder.jp/contests/abc297/submissions/40497449

FPS ライブラリはここ → rk-library/formal_power_series.hpp at master · Ricky-pon/rk-library · GitHub

int main() {
    int n;
    scanf("%d", &n);
 
    FormalPowerSeries<mint> num(n + 1), den(n + 1);
    For(k, 1, n + 1) {
        for (int i = 1; i * k <= n; ++i) {
            if (k % 2 == 1) {
                num[i * k] += k;
            } else {
                num[i * k] -= k;
            }
        }
    }
    den[0] = 1;
    For(k, 1, n + 1) {
        for (int i = 1; i * k <= n; ++i) {
            if (k % 2 == 1) {
                den[i * k] -= 1;
            } else {
                den[i * k] += 1;
            }
        }
    }
    den *= den;
    den.resize(n + 1);
    num /= den;
    printf("%u\n", num[n].val());
}

関連