挿入 DP の二つの型について

問題のネタバレを多く含みます

挿入 DP

隣り合う項の大小関係に条件がついた順列を数え上げるときに使われることが多い。

$\mathrm{dp}(i, S) := $ $(1, 2, \dots, i$ からなる順列のうち、$S$ をみたすものの個数$)$

という状態を考えるのだが、遷移の考え方によって二種類の型に分けられる(と考えている)ので、それぞれについて説明する。

その①「割り込み」型

前から作っていくのではなく、遷移の際に「$1, 2, \dots, i - 1$ からなる順列のどこに $i$ を挿入するか」を考える。

DEGwer さんの数え上げ PDF で紹介されているのはこっち。

超簡単な例題

概要

$1$ 以上 $N$ 以下の整数からなる順列の個数を求めよ。

制約

$ 1 \leq N \leq 10 ^ 5 $

解法

答えが $N!$ になることは知らないとして、DP を考える。

$\mathrm{dp}(i) := $ $(1, 2, \dots, i$ を並べた順列の個数$)$ と定める。

$1, 2, \dots, i$ からなる順列において $i + 1$ を挿入できる箇所は $i + 1$ 通りあるので、

$$ \mathrm{dp}(i + 1) = \mathrm{dp}(i) \times (i + 1) $$

となる。よって、$O(N)$ 時間の DP で答えが求まる。

ABC 267 G - Increasing K Times

atcoder.jp

概要

与えられた数列 $A$ を並べ替えて得られる数列 $B$ であって、$B_i < B_{i + 1}$ を満たす $i$ が $K$ 個あるようなものの個数を求めよ。(ただし、$A$ に含まれる数は値が同じであっても区別する)

制約

$2 \leq N \leq 5000$

解法

kyopro_friends さんの解説と同じ。

値の小さい順に、値が同じものはまとめて、列に挿入していくことを考える。 列を構成している数たちは、今挿入しようとしている数よりも小さいため、

$\mathrm{dp}(i, j) := $ 小さい方から $i$ 種類の数を挿入したとき、$B_k < B_{k + 1}$ なる $k$ が $j$ 個あるような列 $B$ の個数

という DP が回る。

数え上げ PDF 3.2 節

概要

高さ $1$ から $N$ までのビルを横一列に並べる。左からは $K$ 個、右からは $L$ 個のビルが見えるような並べ方は何通りあるか求めよ。

簡単な解説

さっきとは逆に、大きい方(高い方)から順に挿入していくと見通しがよい。

その②「過去改変」型

前から順に列を作っていく。ただし、$p_i$ までを作る際に使えるのは $1, 2, \dots, i$ のみであるとする。

$p_{i + 1} = x$ なる $x$ を $1, 2, \dots, i+1$ から選びたいのだが、このままでは列に $x$ が $2$ つ出現してしまう。そこで、さっきまでの列における $x, x + 1, \dots, i$ は $1$ つずつ後ろにずらして、$x + 1, x + 2, \dots, i + 1$ だったことにする

例えば、$p_6$ までを $1$ 〜 $6$ で作った列が $(3, 2, 4, 1, 5, 6)$ だったとする。ここで $p_7 = 4$ と決めると、新しい数列は $(3, 2, 5, 1, 6, 7, 4)$ となる。$p_3 = 4, p_6 = 5, p_7 = 6$ が $1$ つずつ後ろにずれて、$p_3 = 5, p_6 = 6, p_7 = 7$ だったことになっている。

型①とは異なり今作っている列に挿入するわけではないので分かりにくいが、使った数の集合 $1, 2, \dots, i$ に $k$ を挿入して $x\in \lbrace 1, \dots, k-1, k, k + 1, \dots, i + 1 \rbrace$ としていると思えるので、「挿入 DP」と呼ばれるのも納得できる。

超簡単な例題

概要

$1$ 以上 $N$ 以下の整数からなる順列の個数を求めよ。

制約

$ 1 \leq N \leq 10 ^ 5 $

解法

$\mathrm{dp}(i) := (i $ 項目までに $1, 2, \dots, i$ を並べた順列の個数$)$ と定める。

$i + 1$ 項目としてありうる値は $1, 2, \dots, i + 1$ の $i + 1$ 通りあるので、

$$ \mathrm{dp}(i + 1) = \mathrm{dp}(i) \times (i + 1) $$

となる。よって、$O(N)$ 時間の DP で答えが求まる。

EDPC T - Permutation

atcoder.jp

概要

<, >, ? からなる長さ $N - 1$ の文字列 $s$ が与えられる。

$1$ 以上 $N$ 以下の整数からなる順列 $p$ であって、

  • $s_i = $< ならば $p_i < p_ {i + 1}$
  • $s_i = $> ならば $p_i > p_ {i + 1}$

を満たすものの個数を求めよ。

制約

$2 \leq N \leq 3000$

解法

$\mathrm{dp}(i, j) := $ $( i$ 項目までに $1, 2, \dots, i$ を並べた順列であって、$p_i = j$ であり、不等号の制約を満たすものの個数$)$ と定める。

$p_{i + 1}$ に何を置けるかは、$j$ と $s_i$ を見れば分かる。

  • $s_i =$ < のときは $p_{i + 1}$ として $j, j+1, \dots, i+1$ を選べる。
  • $s_i$ = > のときは $p_{i + 1}$ として $1, 2, \dots, j$ を選べる。$j$ が被っているのが不思議だが、元あった $j$ は $j+1$ だったことになるので問題ない。

ABC 282 G - Similar Permutation

atcoder.jp

概要

$1$ 以上 $N$ 以下の整数からなる順列 $A, B$ であって、隣接する項の大小が $K$ 箇所で一致しているようなものの個数を求めよ。

簡単な解説

列を $2$ つ作る必要があるが、EDPC T - Permutation とほぼ同じ考え方で解ける。詳細は公式解説を参照。

数え上げ PDF 第 5 章 - 問題 1

概要

$1$ 以上 $N$ 以下の整数からなる順列であって、単調増加な $2$ つの列のマージとして書けるものの個数を求めよ。

簡単な解説

「greedy からの帰着」を行ったあとの数え上げパートで、型②の挿入 DP をする。

$\mathrm{dp}(i, j) := $ $( i$ 項目までに $1, 2, \dots, i$ を並べた順列であって、$B$ の列の末尾が $j$ であるようなものの個数$)$

とすると解ける。

AtCoder で青になるまでにやったこと

AtCoder で青になるまでにやったことを紹介します。

AtCoder で黄色になる

ARC に出る

冷える

いかがでしたか?

AtCoder Beginner Contest 295 Ex - E or m

E or m よりは、三 or 川 という感じがする

問題へのリンク

問題

各マスが白、黒、灰色に塗られた $ N \times M $ のグリッドが与えられる。 各灰マスを白か黒に塗り替えてできるグリッドのうち、以下のようにして生成できるものはいくつあるか?

  • グリッドの全てのマスを白で初期化する。
  • 各 $ i = 1, \dots, N $ に対して、$ i $ 行目のマスのうち左からいくつかを黒に塗る。
  • 各 $ j = 1, \dots, M $ に対して、$ j $ 列目のマスのうち上からいくつかを黒に塗る。

制約

  • $ 1 \leq N, M \leq 18 $

前提

とくになし

解法

まずは生成可能な塗り方を観察してみる。

OK            NG
1010  1111    0100  0111
1111  1001    1110  1111
1110  1110    0110  0011
0010  1000    0001  0010

すると、ある塗り方が生成可能であることの必要十分条件

「全ての黒マスについて、そのマスの上全部か左全部が黒で塗られている」

ことだと分かる。

よって、

$ \mathrm{dp}(i, j, f, S) := ($マス $ (i, j) $ まで塗って、左を全て黒で塗っているか否かが $f$ で、上を全て塗っているマスの集合が $S$ であるような塗り方の数$)$

として DP すればよい。時間計算量は $ O(NM2 ^ M) $。

コード

https://atcoder.jp/contests/abc295/submissions/40530818

int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    char c[n][m];
    rep(i, n) rep(j, m) scanf(" %c", &c[i][j]);
 
    vector<vector<vector<mint>>> dp(
        m + 1, vector<vector<mint>>(2, vector<mint>(1 << m, 0)));
    dp[m][1][(1 << m) - 1] = 1;
    rep(i, n) {
        // init
        vector<vector<vector<mint>>> nxt(
            m + 1, vector<vector<mint>>(2, vector<mint>(1 << m, 0)));
        rep(f, 2) rep(S, 1 << m) { nxt[0][1][S] += dp[m][f][S]; }
        dp.swap(nxt);
 
        // dp
        rep(j, m) rep(f, 2) rep(S, 1 << m) {
            if (dp[j][f][S] == 0) continue;
 
            rep(col, 2) {
                if (c[i][j] != '?' && col != c[i][j] - '0') continue;
                if (col == 1 && !f && !(S >> j & 1)) {
                    continue;
                }
 
                int nf = f & col;
                int T = S;
                if (col == 0 && (S >> j & 1)) {
                    T -= (1 << j);
                }
                dp[j + 1][nf][T] += dp[j][f][S];
            }
        }
    }
 
    mint ans = 0;
    rep(f, 2) rep(S, 1 << m) ans += dp[m][f][S];
    printf("%u\n", ans.val());
}

関連

AtCoder Beginner Contest 296 Ex - Unite

グリッドマンユニバースを見に行かないといけない

問題へのリンク

問題

白黒で塗られた $ N \times M $ のグリッドがある。 黒マス全体を連結にするために、追加で黒に塗るべきマスは最小でいくつ?

制約

  • $ 1 \leq N \leq 100 $
  • $ 1 \leq M \leq 7 $ ← 小さい!

前提

とくになし

解法

「連結性 DP」「面倒 DP」などと呼ばれているやつ。

$ \mathrm{dp}(i, S) := ( $$ i $ 行目まで決めて、連結性が $ S $ であるときに塗るべきマスの個数の最小値$)$

のようにして DP する。遷移を計算する際には $ i $ 行目の白黒のパターン $ 2 ^ M $ 通りを全探索する。

連結性について詳しく説明する。

連結性を数列で表現する様子

例えば、上の図において左の二つの塗り方は同じ状態に圧縮できる。一番下の行はどちらも 黒白黒黒白黒白 であり、$ 1 $ 列目と $ 6 $ 列目、$ 3 $ 列目と $ 4 $ 列目の黒マスがそれぞれ連結になっている。よって、白マスには $ 0 $、黒マスには左から順に連結成分番号を振っていくことにすると、二つの塗り方はいずれも $ (1, 0, 2, 2, 0, 1, 0) $ という数列で表現できる。

実装が大変そうだが、私はいつも $ i - 1 $ 行目と $ i $ 行目を合わせた $ 2M $ 頂点のグラフに対する Union-Find 木を用意し、

  1. $ i - 1 $ 行目において、連結成分番号が同じ黒マスをマージ
  2. $ i $ 行目の各黒マスと隣接する黒マスをマージ
  3. 長さ $ M $ の配列 $ a $ を用意し、$ a _ j $ には $ (i, j) $ が白なら $ 0 $、黒なら $ (i, j) $ が属する集合の代表元 $ + 1 $ を入れる
  4. map などを持ちながら $ a $ を順番に舐めて、連結成分番号を $ 1 $ から振り直す

のようにしている。

また、今回は $ i , S $ に加えて、これ以上黒マスが登場しないかを状態として持つ(耳 DP)。

状態を $8$ 進数で表現して int 型に詰め込むと高速化できる。$ 2 ^ k $ 進数にすると、値をビット演算で取り出せるので都合が良い。

コード

遷移を計算する部分で $ N $ と $ M $ を取り違えないように、$ N $ は main 内、$ M $ はグローバルで宣言している。

https://atcoder.jp/contests/abc296/submissions/40333675

int m;
 
int next_phase(vector<int> &v, int phase, int S) {
    if (phase == 0) {
        return (S == 0 ? 0 : 1);
    } else if (phase == 1) {
        return (S == 0 ? 2 : 1);
    } else {
        return (S == 0 ? 2 : -1);
    }
}
 
optional<pair<vector<int>, int>> next_state(pair<vector<int>, int> state,
                                            int S) {
    auto [v, phase] = state;
    auto np = next_phase(v, phase, S);
    if (np == -1) return nullopt;
    if (S == 0) {
        if (*max_element(v.begin(), v.end()) >= 2) {
            return nullopt;
        } else {
            vector<int> res(m, 0);
            return make_pair(res, np);
        }
    }
 
    atcoder::dsu uf(m * 2);
    rep(i, m) For(j, i + 1, m) {
        if (v[i] != 0 && v[i] == v[j]) uf.merge(i, j);
    }
    rep(i, m) {
        if (!(S >> i & 1)) continue;
        if (i + 1 < m && (S >> (i + 1) & 1)) {
            uf.merge(i + m, i + 1 + m);
        }
        if (v[i] != 0) {
            uf.merge(i, i + m);
        }
    }
    rep(i, m) {
        if (v[i] == 0) continue;
        bool flag = false;
        rep(j, m) {
            if (uf.same(i, j + m)) {
                flag = true;
                break;
            }
        }
        if (!flag) return nullopt;
    }
 
    vector<int> nv(m, 0);
    rep(i, m) {
        if (!(S >> i & 1)) continue;
        nv[i] = uf.leader(i + m) + 1;
    }
    map<int, int> mp;
    int cur = 1;
    rep(i, m) {
        if (nv[i] == 0) continue;
        if (mp.find(nv[i]) == mp.end()) {
            mp[nv[i]] = cur;
            ++cur;
        }
        nv[i] = mp[nv[i]];
    }
    return make_pair(nv, np);
}
 
int main() {
    int n;
    scanf("%d%d", &n, &m);
    int mask[n];
    int offset = 0;
    rep(i, n) {
        string s;
        cin >> s;
        mask[i] = 0;
        rep(j, m) if (s[j] == '#') mask[i] += (1 << j), ++offset;
    }
 
    map<pair<vector<int>, int>, int> dp;  // (state, phase) -> num
    {
        vector<int> init(m, 0);
        dp[{init, 0}] = 0;
    }
    rep(i, n) {
        map<pair<vector<int>, int>, int> nxt;
 
        rep(S, 1 << m) {
            if ((S & mask[i]) != mask[i]) continue;
 
            for (auto [state, val] : dp) {
                auto ns = next_state(state, S);
                if (!ns) continue;
 
                int tmp = val + __builtin_popcount(S);
                if (nxt.find(ns.value()) == nxt.end()) {
                    nxt[ns.value()] = tmp;
                } else {
                    chmin(nxt[ns.value()], tmp);
                }
            }
        }
 
        dp.swap(nxt);
    }
 
    int ans = n * m;
    for (auto [state, val] : dp) {
        auto [v, phase] = state;
        if (*max_element(v.begin(), v.end()) <= 1) {
            chmin(ans, val - offset);
        }
    }
    printf("%d\n", ans);
}

関連

AtCoder Beginner Contest 297 Ex - Diff Adjacent

最後の一手で詰まっていた……。

問題へのリンク

問題

整数 $N$ が与えられる。

  • 要素の総和が $N$
  • 隣り合う項の値が異なる

を満たすような数列の長さの総和を $998244353$ で割った余りを求めよ。

前提

とくになし

解法

問題の第一印象は

  • 「$N-1$ 個の条件を全て満たす数列を数え上げる」問題なので、包除が使えそう
  • 長さを足し上げろと言われているので、数列の長さを変数の肩に乗せた形式的冪級数を考えて、最後に微分すればよさそう

だったので、考察の方針は FPS にロックオンした。

項が隣接している部分は $N - 1$ 箇所あるので、条件に違反しているもの(つまり、等しくなっているもの)の決め打ち方 $2 ^ {N - 1}$ 通りに対して包除原理を適用する。

図に示したとおり、条件に違反する隣接部分を決め打つことは、列を等しい値の区間で分割することと同じだと思える。

よって、長さ $ M $ の列の個数は

$$ [x ^ N] \sum _ { K = 1 } ^ \infty \sum _ {c _ 1 + \dots + c _ K = M} \prod _ {i = 1} ^ K (-1) ^ {c _ i - 1} (x ^ {c_i} + x ^ {2c_i} + \dots ) $$

と書ける。$ c _ 1 + \dots + c _ K = M $ の部分を表現するために変数 $y$ を導入すると、上式は

$$ [x ^ N y ^ M] \sum _ { K = 1 } ^ \infty \Brace{ \sum _ {c = 1} ^ \infty \frac{-(-xy) ^ c}{1 - x ^ c} } ^ K = [x ^ N y ^ M] \frac{1}{1 + \sum _ {c = 1} ^ \infty \frac{(-xy) ^ c}{1 - x ^ c}} $$

と整理できる。

求めたいものは長さの総和だったので、$ y $ の指数を微分で下ろしてくればよい。また、 $ M $ を消すために全体を $1 - y$ で割って $y ^ N$ の係数を取ることにすると、求める値は

$$ [x ^ N y ^ N] \frac{y}{1 - y} \cdot \frac{\partial}{\partial y} \left( \frac{1}{1 + \sum _ {c = 1} ^ \infty \frac{(-xy) ^ c}{1 - x ^ c}} \right) $$

となる。めでたしめでたし……と思いきや、これを計算する方法が分からなくて困った……(ここでコンテスト終了)。

実は、$ 1 - y $ で割る代わりに、$ y $ に $ 1 $ を代入すると係数の和を取ることができる!

よって、

$$ [x ^ N] \left. \frac{\partial}{\partial y} \left( \frac{1}{1 + \sum _ {c = 1} ^ \infty \frac{(-xy) ^ c}{1 - x ^ c}} \right) \right| _ {y = 1} = [x ^ N] \frac{-\sum _ {c = 1} ^N \frac{c(-x) ^ c}{1 - x ^ c} }{\left(1 + \sum _ {c = 1} ^ N \frac{(-x) ^ c}{1 - x ^ c}\right) ^ 2} $$

が答えとなる。

$\frac{(-x) ^ c}{1 - x ^ c}$ のうち非零の係数を持つ項は $\Floor{N / c}$ 個しかないので、$ \sum _ {c = 1} ^ N \frac{(-x) ^ c}{1 - x ^ c} $ は $ O(N \log N) $ 時間で構築できる(調和数のやつ)。

ここまでくれば、あとは FPS ライブラリに投げるだけ。

コード

https://atcoder.jp/contests/abc297/submissions/40497449

FPS ライブラリはここ → rk-library/formal_power_series.hpp at master · Ricky-pon/rk-library · GitHub

int main() {
    int n;
    scanf("%d", &n);
 
    FormalPowerSeries<mint> num(n + 1), den(n + 1);
    For(k, 1, n + 1) {
        for (int i = 1; i * k <= n; ++i) {
            if (k % 2 == 1) {
                num[i * k] += k;
            } else {
                num[i * k] -= k;
            }
        }
    }
    den[0] = 1;
    For(k, 1, n + 1) {
        for (int i = 1; i * k <= n; ++i) {
            if (k % 2 == 1) {
                den[i * k] -= 1;
            } else {
                den[i * k] += 1;
            }
        }
    }
    den *= den;
    den.resize(n + 1);
    num /= den;
    printf("%u\n", num[n].val());
}

関連

AtCoder Beginner Contest 212 H - Nim Counting

H は Hadamard transform の H

問題へのリンク

問題

整数 $ N $ と、整数 $ A_1, A_2, \dots, A_K $ が与えられる。

$ 1 \leq M \leq N $ かつ $ d_i \in \Brace{A_1, A_2, \dots, A_K} $ を満たす長さ $ M $ の整数列 $ d $ のうち、 $$ d_1 \oplus d_2 \oplus \dots \oplus d_M \neq 0 $$ となるものの個数を $998244353$ で割った余りを求めよ。

前提

ここではアダマール変換をブラックボックスとして扱うので、定義は無理に覚えようとしなくてもよい。

定義 $ 1 $(アダマール変換)
$ 2 ^ n \times 2 ^ n $ 行列 $ H _ n $ を $$ \begin{split} H _ 0 &= 1 \\\ H _ {n} &= \begin{bmatrix} H _ {n - 1} & H _ {n - 1} \\ H _ {n - 1} & -H _ {n - 1} \\ \end{bmatrix} \end{split} $$ と定める。$ H_n $ が定める変換をアダマール変換という。また、$ H _ n $ をアダマール行列という。
定理 $ 2 $(逆変換)
アダマール行列 $ H _ n $ に対し、 $$ H _ n H _ n = 2 ^ n I $$ が成り立つ。

ここからは重要。

定理 $ 3 $
$ H _ n $ をアダマール行列、$ a, b $ を長さ $ 2 ^ n $ のベクトル、$ c $ を定数とすると $$ \begin{split} H _ n (ca) &= c H _ n a \\\ H _ n (a + b) &= H _ n a + H _ n b \end{split} $$ が成り立つ。
定理 $ 4 $(畳み込み定理)
$ H _ n $ をアダマール行列、$ a, b $ を長さ $ 2 ^ n $ のベクトルとすると $$ H _ n (a * b) = (H _ n a)(H _ n b) ^ \top $$ が成り立つ。 ただし、$ (a * b) _ k = \sum _ {i \oplus j = k} a _ i b _ j $(つまり、xor 畳み込み)である。
定理 $ 5 $(高速アダマール変換)
$ a $ を長さ $ 2 ^ n $ のベクトルとすると、$ H _ n a $ は $ O(n2 ^ n) $ 時間で計算できる。

解法

$ n = 16 $、$ A = 2 ^ {n} ( > \max A_i ) $ とする。

$ \mathrm{dp}(i, j) := (i$ 個目の山まで見て、総 xor が $ j $ であるような初期状態の個数$)$ とおいて DP したくなるが、状態数 $ O(NA) $、遷移 $ O(A) $ 時間になって困る。

xor 畳み込みを用いて $ \mathrm{dp}(i-1, * ) $ から $ \mathrm{dp}(i, * )$ をまとめて求めるようにすると $ O(NA\log A) $ 時間になるが、まだ遅い。

$ f $ を $$ f _ i = \begin{cases} 1 & i \in \Brace{A_1, A_2, \dots, A_K} \\ 0 & \text{otherwise} \end{cases} $$ を満たす長さ $ A $ のベクトルとする。求めたいものは $ f + f * f + \dots + \underbrace{f * f * \dots * f} _ n $ である。順番に計算していると上と同じで $ O(NA\log A) $ かかってしまうので、一気にアダマール変換することを考える。

$$ \begin{split} (H _ n (f + f * f + \dots + \underbrace{f * f * \dots * f} _ n))_k &= (H _ n(f)) _ k + (H _ n(f * f)) _ k + \dots + (H _ n(\underbrace{f * f * \dots * f} _ n)) _ k \\ &= \sum _ {i = 1} ^ n ( ( H _ n f ) _ k ) ^ i \end{split} $$

であり、等比数列の和の公式を使うと $ 2 $ 行目が $ O( \log N \log mod ) $ で求まる。$ 1 $ つ目のイコールは定理 $ 4 $ から、$ 2 $ つ目のイコールは定理 $ 5 $ から従う。

アダマール変換、逆アダマール変換は $ O(A \log A) $ 時間でできるので、全体で $ O(A (\log A + \log N \log mod)) $ 時間で答えを得ることができる。

コード

https://atcoder.jp/contests/abc212/submissions/38703684

template <typename T>
void fast_walsh_hadamard_transform(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                T x = f[j], y = f[j | i];
                f[j] = x + y, f[j | i] = x - y;
            }
        }
    }
}
template <typename T>
void inverse_fast_walsh_hadamard_transform(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                T x = f[j], y = f[j | i];
                f[j] = (x + y) / 2, f[j | i] = (x - y) / 2;
            }
        }
    }
}
 
int main() {
    int n, K;
    scanf("%d%d", &n, &K);
 
    vector<mint> f(1 << 16, 0);
    rep(i, K) {
        int a;
        scanf("%d", &a);
        f[a] = 1;
    }
 
    fast_walsh_hadamard_transform(f);
    for (auto& x : f) {
        if (x.val() == 0) {
            x = 0;
        } else if (x.val() == 1) {
            x = n;
        } else {
            x = (x.pow(n + 1) - x) / (x - 1);
        }
    }
    inverse_fast_walsh_hadamard_transform(f);
 
    printf("%u\n", accumulate(f.begin() + 1, f.end(), mint(0)).val());
}

関連

AtCoder Beginner Contest 212 G - Power Pair

放置している ABC - G, Ex を解き進めたい

問題へのリンク

問題

素数 $P$ が与えられる。

  • $ 0 \leq x \leq P - 1 $
  • $ 0 \leq y \leq P - 1 $
  • ある整数 $ n $ が存在して、$ x ^ n \equiv _ P y$

を満たす整数の組 $ (x, y) $ の個数を $ 998244353 $ で割った余りを求めよ。

前提

定義 $ 1 $(原始根)
整数 $ p $ に対し、$ r^0, r^1, \dots, r^{p-2} $ が相異なるような $ r \in (\mathbb{Z} / \mathbb{pZ}) ^ \times $ を $ p $ の原始根という。 すなわち、$ r $ は $ (\mathbb{Z} / \mathbb{pZ}) ^ \times $ の生成元である。
定理 $ 2 $
任意の素数に対し、原始根が存在する。

解法

$ x = 0 $ のとき、対応する $ y $ は $ 0 $ のみ。以降は $ x \neq 0 $ の場合を考える。

$ p $ の原始根を $ r $ とすると、ある整数 $ a, b $ を用いて $ x = r ^ a, ~ y = r ^ b $ と書ける。

$$ \begin{split} x ^ n \equiv _ P y &\iff r ^ {an} \equiv _ P r ^ b \\ &\iff an \equiv _ {P - 1} b \end{split} $$

なので、ある整数 $ n $ が存在して $ an \equiv _ {P - 1} b $ を満たすような $ (a, b) $ を数えればよい。

$ \mod P - 1 $ 上で $ \Brace{ka \mid k \in \mathbb{Z}} = \Brace{k\gcd(a, P - 1) \mid k \in \mathbb{Z}} $ が成り立つので、$ a $ を固定したとき、対応する $ b $ は $ \frac{P - 1}{\gcd(a, P - 1)} $ 個存在する。

$ \gcd(a, P - 1) $ ごとにまとめて数えることにすると、答えは

$$ \sum_{g \mid P - 1} \frac{P - 1}{g} \times \# \Brace{x \in \mathbb{Z}/(P-1)\mathbb{Z} \mid g = \gcd(x, P - 1)} $$

となる。

あとは $$ f(g) = \# \Brace{x \in \mathbb{Z}/(P-1)\mathbb{Z} \mid g = \gcd(x, P - 1)} $$ を求めれば終わり。 $$ \# \Brace{x \in \mathbb{Z}/(P-1)\mathbb{Z} \mid g \mid \gcd(x, P - 1)} = \sum_{g \mid h} f(h) $$ は簡単に計算できる($ g $ の倍数の個数に等しい)ので、低速メビウス変換*1をすれば OK。

コード

https://atcoder.jp/contests/abc212/submissions/38687433

int main() {
    lint p;
    scanf("%lld", &p);
 
    vector<lint> ds;
    for (lint i = 1; i * i <= p - 1; ++i) {
        if ((p - 1) % i == 0) {
            ds.push_back(i);
            if (i * i != p - 1) {
                ds.push_back((p - 1) / i);
            }
        }
    }
    sort(ds.begin(), ds.end());
 
    mint dp[ds.size()];
    rrep(i, ds.size()) {
        dp[i] += (p - 1) / ds[i];
        rep(j, i) {
            if (ds[i] % ds[j] == 0) {
                dp[j] -= dp[i];
            }
        }
    }
 
    mint ans = 1;
    rep(i, ds.size()) { ans += dp[i] * ((p - 1) / ds[i]); }
    printf("%u\n", ans.val());
}

関連

*1:造語なので注意