AtCoder Beginner Contest 212 H - Nim Counting
H は Hadamard transform の H
問題
整数 $ N $ と、整数 $ A_1, A_2, \dots, A_K $ が与えられる。
$ 1 \leq M \leq N $ かつ $ d_i \in \Brace{A_1, A_2, \dots, A_K} $ を満たす長さ $ M $ の整数列 $ d $ のうち、 $$ d_1 \oplus d_2 \oplus \dots \oplus d_M \neq 0 $$ となるものの個数を $998244353$ で割った余りを求めよ。
前提
ここではアダマール変換をブラックボックスとして扱うので、定義は無理に覚えようとしなくてもよい。
ここからは重要。
解法
$ n = 16 $、$ A = 2 ^ {n} ( > \max A_i ) $ とする。
$ \mathrm{dp}(i, j) := (i$ 個目の山まで見て、総 xor が $ j $ であるような初期状態の個数$)$ とおいて DP したくなるが、状態数 $ O(NA) $、遷移 $ O(A) $ 時間になって困る。
xor 畳み込みを用いて $ \mathrm{dp}(i-1, * ) $ から $ \mathrm{dp}(i, * )$ をまとめて求めるようにすると $ O(NA\log A) $ 時間になるが、まだ遅い。
$ f $ を $$ f _ i = \begin{cases} 1 & i \in \Brace{A_1, A_2, \dots, A_K} \\ 0 & \text{otherwise} \end{cases} $$ を満たす長さ $ A $ のベクトルとする。求めたいものは $ f + f * f + \dots + \underbrace{f * f * \dots * f} _ n $ である。順番に計算していると上と同じで $ O(NA\log A) $ かかってしまうので、一気にアダマール変換することを考える。
$$ \begin{split} (H _ n (f + f * f + \dots + \underbrace{f * f * \dots * f} _ n))_k &= (H _ n(f)) _ k + (H _ n(f * f)) _ k + \dots + (H _ n(\underbrace{f * f * \dots * f} _ n)) _ k \\ &= \sum _ {i = 1} ^ n ( ( H _ n f ) _ k ) ^ i \end{split} $$
であり、等比数列の和の公式を使うと $ 2 $ 行目が $ O( \log N \log mod ) $ で求まる。$ 1 $ つ目のイコールは定理 $ 4 $ から、$ 2 $ つ目のイコールは定理 $ 5 $ から従う。
アダマール変換、逆アダマール変換は $ O(A \log A) $ 時間でできるので、全体で $ O(A (\log A + \log N \log mod)) $ 時間で答えを得ることができる。
コード
https://atcoder.jp/contests/abc212/submissions/38703684
template <typename T> void fast_walsh_hadamard_transform(vector<T>& f) { int n = f.size(); for (int i = 1; i < n; i <<= 1) { for (int j = 0; j < n; j++) { if ((j & i) == 0) { T x = f[j], y = f[j | i]; f[j] = x + y, f[j | i] = x - y; } } } } template <typename T> void inverse_fast_walsh_hadamard_transform(vector<T>& f) { int n = f.size(); for (int i = 1; i < n; i <<= 1) { for (int j = 0; j < n; j++) { if ((j & i) == 0) { T x = f[j], y = f[j | i]; f[j] = (x + y) / 2, f[j | i] = (x - y) / 2; } } } } int main() { int n, K; scanf("%d%d", &n, &K); vector<mint> f(1 << 16, 0); rep(i, K) { int a; scanf("%d", &a); f[a] = 1; } fast_walsh_hadamard_transform(f); for (auto& x : f) { if (x.val() == 0) { x = 0; } else if (x.val() == 1) { x = n; } else { x = (x.pow(n + 1) - x) / (x - 1); } } inverse_fast_walsh_hadamard_transform(f); printf("%u\n", accumulate(f.begin() + 1, f.end(), mint(0)).val()); }