有個粒子初始在 \(0\) 位置,\(1\cdots n\) 位置分別爲有一個對撞器,如果在 \(0\) 位置則向右,如果在 \(n + 1\) 位置則向左。每個對撞器有一個 \(01\) 串,初始所有對撞器的指針都在開頭,當粒子走到 \(i\) 位置時,對撞器所指的值爲 \(0\) 則不改變方向,否則反向,指針指向下一個位置,如果在串的末尾則指向開頭。求最小的週期長度 \(c\) 滿足任意 \(t\) 時間和 \(t + c\) 時間粒子在同一位置。
\(1\le n \le 10^6\),\(\sum |s_i|\le 10^6\)。
注意到對於一個位置,無論在右邊轉了多久,回到這裏後和直接從右邊回來是一樣。左邊同理。所以我們只用考慮 \(0,1,2\) 這三個位置。
還有一個顯然的事實是每個粒子只用保留它的最小整週期,然後你就可以跑一個暴力,求出一個週期從左邊進入 \(a_i\) 次,往右邊走出去 \(b_i\) 次。顯然這個過程是對稱的。
設 \(f_i\) 代表最後過程中 \(i\to {i+1}\) 的次數,由於左邊和右邊會右七七八八的破事,所以 \(i\) 這個位置可能進進出出多個週期,所以應該 \(\frac{f_i}{f_{i+1}}=\frac{a_i}{b_i}\)。使用主元法,用 \(f_0\) 表示出所有 \(f_i=f_0\prod_{j=1}^{i}\frac{b_j}{a_j}\)。我們需要構造這個 \(f_0\) 使得每一個 \(f_i\) 是整數,並且 \(b_i\mid f_i\)。根據每個質因子考慮,隨時把不夠的部分補進 \(f_0\) 裏。不妨設 \(b_n\not=0\),統計這一部分即可。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 5, mod = 998244353;
int qmod(int x) { return x >= mod ? x - mod : x; }
int ksm(int a, int b = mod - 2)
{
int ret = 1;
for (; b; b >>= 1, a = 1ll * a * a % mod) if (b & 1) ret = 1ll * ret * a % mod;
return ret;
}
template <typename T>
void read(T &x)
{
T sgn = 1;
char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
for (x = 0; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
x *= sgn;
}
int n, prime[maxn], cnt, mn[maxn];
bool vis[maxn];
char s[maxn];
int nxt[maxn], a[maxn], b[maxn];
int mx[maxn], num[maxn];
void sieve(int mx)
{
for (int i = 2; i <= mx; i++)
{
if (!vis[i]) prime[++cnt] = i, mn[i] = i;
for (int j = 1; j <= cnt && prime[j] * i <= mx; j++)
{
vis[i * prime[j]] = 1;
mn[i * prime[j]] = prime[j];
if (i % prime[j] == 0) break;
}
}
}
int main()
{
read(n); sieve(1000000);
for (int _ = 1; _ <= n; _++)
{
scanf("%s", s + 1);
int len = strlen(s + 1);
for (int i = 2, j = 0; i <= len; i++)
{
while (j && s[i] != s[j + 1]) j = nxt[j];
if (s[i] == s[j + 1]) j++;
nxt[i] = j;
}
int per = len % (len - nxt[len]) == 0 ? len - nxt[len] : len;
int cur = 0, dir = 1, pos = 1;
do
{
cur += dir;
if (cur == 1)
{
dir == 1 ? a[_]++ : b[_]++;
if (s[pos] == '1') dir = -dir;
pos = pos % per + 1;
}
else dir = cur == 0 ? 1 : -1;
} while (cur != 0 || pos != 1);
// if (n == 50)
// printf("! %d %d\n", a[_], b[_]);
}
for (int i = 1; i <= n; i++)
{
int now = a[i];
while (now > 1)
{
int p = mn[now];
while (now % p == 0) now /= p, num[p]--;
mx[p] += max(0, -num[p]);
num[p] = max(num[p], 0);
}
if (!b[i]) break;
now = b[i];
while (now > 1)
{
int p = mn[now];
while (now % p == 0) now /= p, num[p]++;
mx[p] = max(mx[p], -num[p]);
}
}
int f0 = 1;
for (int i = 1; i <= cnt; i++) f0 = 1ll * f0 * ksm(prime[i], mx[prime[i]]) % mod;
int ans = f0;
for (int i = 1; i <= n; i++)
{
f0 = 1ll * f0 * b[i] % mod * ksm(a[i]) % mod;
ans = qmod(ans + f0);
} printf("%d\n", qmod(ans + ans));
return 0;
}