AtCoder Beginner Contest 212 H - Nim Counting

H は Hadamard transform の H

問題へのリンク

問題

整数 $ N $ と、整数 $ A_1, A_2, \dots, A_K $ が与えられる。

$ 1 \leq M \leq N $ かつ $ d_i \in \Brace{A_1, A_2, \dots, A_K} $ を満たす長さ $ M $ の整数列 $ d $ のうち、 $$ d_1 \oplus d_2 \oplus \dots \oplus d_M \neq 0 $$ となるものの個数を $998244353$ で割った余りを求めよ。

前提

ここではアダマール変換をブラックボックスとして扱うので、定義は無理に覚えようとしなくてもよい。

定義 $ 1 $(アダマール変換)
$ 2 ^ n \times 2 ^ n $ 行列 $ H _ n $ を $$ \begin{split} H _ 0 &= 1 \\\ H _ {n} &= \begin{bmatrix} H _ {n - 1} & H _ {n - 1} \\ H _ {n - 1} & -H _ {n - 1} \\ \end{bmatrix} \end{split} $$ と定める。$ H_n $ が定める変換をアダマール変換という。また、$ H _ n $ をアダマール行列という。
定理 $ 2 $(逆変換)
アダマール行列 $ H _ n $ に対し、 $$ H _ n H _ n = 2 ^ n I $$ が成り立つ。

ここからは重要。

定理 $ 3 $
$ H _ n $ をアダマール行列、$ a, b $ を長さ $ 2 ^ n $ のベクトル、$ c $ を定数とすると $$ \begin{split} H _ n (ca) &= c H _ n a \\\ H _ n (a + b) &= H _ n a + H _ n b \end{split} $$ が成り立つ。
定理 $ 4 $(畳み込み定理)
$ H _ n $ をアダマール行列、$ a, b $ を長さ $ 2 ^ n $ のベクトルとすると $$ H _ n (a * b) = (H _ n a)(H _ n b) ^ \top $$ が成り立つ。 ただし、$ (a * b) _ k = \sum _ {i \oplus j = k} a _ i b _ j $(つまり、xor 畳み込み)である。
定理 $ 5 $(高速アダマール変換)
$ a $ を長さ $ 2 ^ n $ のベクトルとすると、$ H _ n a $ は $ O(n2 ^ n) $ 時間で計算できる。

解法

$ n = 16 $、$ A = 2 ^ {n} ( > \max A_i ) $ とする。

$ \mathrm{dp}(i, j) := (i$ 個目の山まで見て、総 xor が $ j $ であるような初期状態の個数$)$ とおいて DP したくなるが、状態数 $ O(NA) $、遷移 $ O(A) $ 時間になって困る。

xor 畳み込みを用いて $ \mathrm{dp}(i-1, * ) $ から $ \mathrm{dp}(i, * )$ をまとめて求めるようにすると $ O(NA\log A) $ 時間になるが、まだ遅い。

$ f $ を $$ f _ i = \begin{cases} 1 & i \in \Brace{A_1, A_2, \dots, A_K} \\ 0 & \text{otherwise} \end{cases} $$ を満たす長さ $ A $ のベクトルとする。求めたいものは $ f + f * f + \dots + \underbrace{f * f * \dots * f} _ n $ である。順番に計算していると上と同じで $ O(NA\log A) $ かかってしまうので、一気にアダマール変換することを考える。

$$ \begin{split} (H _ n (f + f * f + \dots + \underbrace{f * f * \dots * f} _ n))_k &= (H _ n(f)) _ k + (H _ n(f * f)) _ k + \dots + (H _ n(\underbrace{f * f * \dots * f} _ n)) _ k \\ &= \sum _ {i = 1} ^ n ( ( H _ n f ) _ k ) ^ i \end{split} $$

であり、等比数列の和の公式を使うと $ 2 $ 行目が $ O( \log N \log mod ) $ で求まる。$ 1 $ つ目のイコールは定理 $ 4 $ から、$ 2 $ つ目のイコールは定理 $ 5 $ から従う。

アダマール変換、逆アダマール変換は $ O(A \log A) $ 時間でできるので、全体で $ O(A (\log A + \log N \log mod)) $ 時間で答えを得ることができる。

コード

https://atcoder.jp/contests/abc212/submissions/38703684

template <typename T>
void fast_walsh_hadamard_transform(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                T x = f[j], y = f[j | i];
                f[j] = x + y, f[j | i] = x - y;
            }
        }
    }
}
template <typename T>
void inverse_fast_walsh_hadamard_transform(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                T x = f[j], y = f[j | i];
                f[j] = (x + y) / 2, f[j | i] = (x - y) / 2;
            }
        }
    }
}
 
int main() {
    int n, K;
    scanf("%d%d", &n, &K);
 
    vector<mint> f(1 << 16, 0);
    rep(i, K) {
        int a;
        scanf("%d", &a);
        f[a] = 1;
    }
 
    fast_walsh_hadamard_transform(f);
    for (auto& x : f) {
        if (x.val() == 0) {
            x = 0;
        } else if (x.val() == 1) {
            x = n;
        } else {
            x = (x.pow(n + 1) - x) / (x - 1);
        }
    }
    inverse_fast_walsh_hadamard_transform(f);
 
    printf("%u\n", accumulate(f.begin() + 1, f.end(), mint(0)).val());
}

関連