#include<bits/stdc++.h>
#define cs const
#define pb push_back
using namespace std;
// Reads the next integer from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes the result negative.
// `ch` is an int so that isdigit() only ever sees EOF or an unsigned-char value (a plain char
// made isdigit undefined for bytes >= 0x80). Hitting EOF before any digit now returns 0 instead
// of spinning forever in the skip loop.
int read(){
    int cnt = 0, f = 1, ch = getchar();
    while(ch != EOF && !isdigit(ch)){
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch != EOF && isdigit(ch)) cnt = cnt * 10 + (ch - '0'), ch = getchar();
    return cnt * f;
}
const int N = 1e5 + 50;       // array capacity (indices up to n are used)
const int Mod = 998244353;
// Modular arithmetic helpers. Arguments are assumed to already lie in [0, Mod);
// 2 * (Mod - 1) still fits in an int, so add() cannot overflow under that assumption.
int add(int a, int b){
    int s = a + b;
    if(s >= Mod) s -= Mod;
    return s;
}
int dec(int a, int b){
    int s = a - b;
    if(s < 0) s += Mod;
    return s;
}
int mul(int a, int b){ return (int)((long long)a * b % Mod); }
// In-place variants.
void Add(int &a, int b){ a = add(a, b); }
void Dec(int &a, int b){ a = dec(a, b); }
void Mul(int &a, int b){ a = mul(a, b); }
// Binary exponentiation: a^b mod Mod for b >= 0.
int ksm(int a, int b){
    int res = 1;
    while(b){
        if(b & 1) res = mul(res, a);
        a = mul(a, a);
        b >>= 1;
    }
    return res;
}
int n, A, coe[N], fac[N], ifac[N];
// Modular inverse of a via 1/a = (a-1)! / a!; valid for 1 <= a <= the bound given to fac_init.
// inv(0) returns 1 by convention.
int inv(int a){
    if(!a) return 1;
    return mul(ifac[a], fac[a - 1]);
}
// Binomial coefficient C(n, m) mod Mod; 0 when the arguments are out of range.
int C(int n, int m){
    if(n < 0 || m < 0 || n < m) return 0;
    return mul(fac[n], mul(ifac[n - m], ifac[m]));
}
// Fills fac[0..lim], ifac[0..lim] and coe[i] = i^(i-2) for i <= lim
// (Cayley's count of labelled trees on i vertices; coe[1] = coe[2] = 1 are set directly).
void fac_init(int lim){
    fac[0] = fac[1] = ifac[0] = ifac[1] = coe[1] = coe[2] = 1;
    for(int i = 1; i <= lim; i++) fac[i] = mul(fac[i - 1], i);
    ifac[lim] = ksm(fac[lim], Mod - 2);      // Fermat inverse of lim!, then walk downwards
    for(int i = lim - 1; i >= 2; i--) ifac[i] = mul(ifac[i + 1], i + 1);
    for(int i = 3; i <= lim; i++) coe[i] = ksm(i, i - 2);
}
#define poly vector<int>
const int K = 18;             // supports transforms of length up to 2^K
// w[lv][j] = r^j, where r is a primitive 2^lv-th root of unity mod Mod (3 is a primitive root of
// 998244353). w[lv] holds 2^(lv-1) entries: the twiddles of the butterfly stage with half-length 2^(lv-1).
poly w[K + 1];
void NTT_init(){
    for(int lv = 1; lv <= K; lv++) w[lv].resize(1 << (lv - 1));
    const int root = ksm(3, (Mod - 1) / (1 << K));
    w[K][0] = 1;
    for(int j = 1; j < (1 << (K - 1)); j++) w[K][j] = mul(w[K][j - 1], root);
    // The square of a primitive 2^(lv+1)-th root is a primitive 2^lv-th root: take every other entry.
    for(int lv = K - 1; lv >= 1; lv--)
        for(int j = 0; j < (1 << (lv - 1)); j++) w[lv][j] = w[lv + 1][j << 1];
}
int f[N], g[N];
int up, bit; poly rev;
void init(int deg){
up=1; bit=0;while(up<deg) up<<=1,++bit; rev.resize(up);for(int i=0; i<up; i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));}void NTT(poly &a, int typ){for(int i=0;i<up;i++)if(i<rev[i])swap(a[i],a[rev[i]]);for(int i=1,l=1; i<up; i<<=1,++l)for(int j=0; j<up; j+=(i<<1))for(int k=0; k<i; k++){
int x=a[k+j], y=mul(w[l][k],a[k+j+i]);
a[k+j]=add(x,y); a[k+j+i]=dec(x,y);}if(typ==-1){reverse(a.begin()+1,a.end());for(int i=0,iv=ksm(up,Mod-2);i<up;i++)Mul(a[i],iv);}}poly operator * (poly a, poly b){
int deg = a.size()+b.size()-1;init(deg);
a.resize(up); b.resize(up);NTT(a,1);NTT(b,1);for(int i=0; i<up; i++)Mul(a[i],b[i]);NTT(a,-1);
a.resize(deg); return a;}void work(int l, int r){if(l==r) return;
int mid = (l+r) >> 1;work(l,mid);
poly A, B;for(int i=l;i<=mid;i++) A.pb(f[i]);
B.pb(0);for(int i=1;i<=r-l;i++) B.pb(g[i]);
A = A * B;for(int i=mid+1;i<=r;i++)Add(f[i],mul(inv(i),A[i-l]));work(mid+1,r);}int main(){freopen("forest.in","r",stdin);freopen("forest.out","w",stdout);
n = read(), A = read(); fac_init(n); NTT_init();for(int i=1; i<=A; i++) g[i]=mul(ifac[i-1],coe[i]);
f[0]=1;work(0,n); cout<<mul(fac[n],f[n]); return 0;}
T2:
一个字符串,问有多少 S[l,r] 满足其可以分成相同的 K 段,n≤3e5
将问题转换一下：记 $\mathrm{lcs}(i,j)$ 为前缀 $S[1,i]$ 与前缀 $S[1,j]$（$i<j$）的最长公共后缀长度，则条件等价于 $\mathrm{lcs}(i,j)\ge (j-i)(K-1)$。于是考虑建出 SAM，然后枚举 $i$ 和 lca。
在 lca 处查询：lca 的子树中、除去 $i$ 所在的那棵儿子子树之外，满足 $j\in\left(i,\ i+\left\lfloor\frac{len}{K-1}\right\rfloor\right]$ 的 $j$ 的个数（$len$ 为 lca 的 len）。
不能暴力跳,考虑链分治,将一个询问 i 放到若干条重链上,每条重链需要做一个前缀,即将一个前缀的轻儿子全部插入,那么我们就从上往下做,暴力插入轻儿子
但注意到这里每个 lca 的 $len$ 不一样。换个角度，把条件写成 $i\in\left[j-\left\lfloor\frac{len}{K-1}\right\rfloor,\ j\right)$，这样同一棵轻子树里的 $j$ 对应的 $len$ 就是一样的了，于是变成区间加、单点查询；然后到链尾要特殊处理，发现就是一个子树查询。
考场脑子抽了写了个线段树合并,其实离线下来用一个树状数组加差分就可以实现子树查询
#include<bits/stdc++.h>
#define cs const
#define pb push_back
using namespace std;
// Reads the next integer from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes the result negative.
// `ch` is an int so that isdigit() only ever sees EOF or an unsigned-char value (a plain char
// made isdigit undefined for bytes >= 0x80). Hitting EOF before any digit now returns 0 instead
// of spinning forever in the skip loop.
int read(){
    int cnt = 0, f = 1, ch = getchar();
    while(ch != EOF && !isdigit(ch)){
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch != EOF && isdigit(ch)) cnt = cnt * 10 + (ch - '0'), ch = getchar();
    return cnt * f;
}
const int N = 6e5 + 50;   // capacity for SAM states (a length-n string's SAM has at most 2n-1 states)
using ll = long long;
int n, K;
int ps[N];                // ps[i]: SAM state created for prefix i
int rt[N];                // rt[u]: segment-tree root holding the prefix positions below state u
char S[N];                // input string, 1-based
ll Ans;
namespace SGT{
cs int N = ::N * 40;
int ls[N], rs[N], sz[N], nd;
#define mid ((l+r)>>1)
void ins(int &x, int l, int r, int p){if(!x) x = ++nd; ++sz[x];if(l == r) return;(p<=mid)?ins(ls[x],l,mid,p):ins(rs[x],mid+1,r,p);}int merge(int x, int y){if(!x||!y) return x|y; int nx=++nd;
ls[nx]=merge(ls[x],ls[y]);
rs[nx]=merge(rs[x],rs[y]);
sz[nx]=sz[ls[nx]]+sz[rs[nx]]; return nx;}int query(int x, int l, int r, int L, int R){if(!x) return 0;if(L<=l&&r<=R) return sz[x]; int as=0;if(L<=mid) as+=query(ls[x],l,mid,L,R);if(R>mid) as+=query(rs[x],mid+1,r,L,R); return as;}int qry(int x, int l, int r){if(l>n) l=n;if(r>n) r=n;if(l>r) return 0;
return query(x,1,n,l,r);}}namespace SAM{
int ch[N][26],lk[N],len[N],r[N],nd=1,las=1;
int extend(int c, int k){
int now=++nd, p=las; len[now]=len[p]+1; r[now]=k;for(;p&&!ch[p][c];p=lk[p]) ch[p][c]=now;if(!p) lk[now] = 1;
else{
int q = ch[p][c];if(len[q] == len[p]+1) lk[now] = q;
else{
int cl=++nd; len[cl]=len[p]+1; lk[cl]=lk[q];memcpy(ch[cl],ch[q],sizeof(ch[q]));for(;p&&ch[p][c]==q; p=lk[p]) ch[p][c]=cl;
lk[now]=lk[q]=cl;}} las = now; return now;}void radix_sort(){
static int bin[N], A[N];for(int i=1; i<=nd; i++) bin[len[i]]++;for(int i=1; i<=n; i++) bin[i]+=bin[i-1];for(int i=nd; i>=1; i--) A[bin[len[i]]--]=i;for(int i=nd; i>=2; i--){
int u=A[i];if(r[u]) Ans+=(ll)SGT::qry(rt[u],r[u],r[u]+len[u]/(K-1));if(r[u])SGT::ins(rt[u],1,n,r[u]);
rt[lk[u]]=SGT::merge(rt[lk[u]],rt[u]);}}}
using SAM::len;
using SAM::r;
// G[u]: children of state u in the suffix-link tree.
// v[x]: positions c whose path from ps[c] to the root enters x's heavy chain at x (filled by ins).
vector<int> G[N], v[N];
// Heavy-light decomposition data: parent, chain head, heavy child (0 = none), subtree size.
int fa[N], top[N], son[N], sz[N];
void pre_dfs(int u){
sz[u] = 1;
for(int v : G[u]){
fa[v] = u; pre_dfs(v); sz[u] += sz[v];if(sz[son[u]] < sz[v]) son[u] = v;}}void dfs(int u, int tp){
top[u] = tp;if(son[u])dfs(son[u], tp);for(int v: G[u])if(v ^ son[u])dfs(v, v);}void ins(int x, int c){while(x){
v[x].pb(c), x=top[x];
if(fa[x]){
Ans-=SGT::qry(rt[x],c+1,c+len[fa[x]]/(K-1));
Ans+=SGT::qry(rt[fa[x]],c+1,c+len[fa[x]]/(K-1));} x=fa[x];}}namespace BIT{
int c[N];
void add(int x, int v){if(x<1) x=1;if(x>n) return;for(;x<=n;x+=x&-x) c[x]+=v;}int ask(int x){ int as=0;for(;x;x-=x&-x) as+=c[x]; return as;}}void subwork(int u, int deg, int c){if(r[u])BIT::add(r[u]-deg,c),BIT::add(r[u]+1,-c);for(int v: G[u])subwork(v, deg, c);}void work(int u){for(int x = u; x; x = son[x])for(int v: G[x])if(v ^ son[x])work(v);for(int x = u; x; x = son[x]){for(int c: v[x]) Ans += (ll)BIT::ask(c);for(int v: G[x])if(v ^ son[x])subwork(v,len[x]/(K-1),1);}for(int x = u; x; x = son[x])for(int v: G[x])if(v ^ son[x])subwork(v,len[x]/(K-1),-1);}int main(){freopen("sutoringu.in","r",stdin);freopen("sutoringu.out","w",stdout);
n = read(); K = read();scanf("%s",S+1);for(int i=1; i<=n; i++) ps[i]=SAM::extend(S[i]-'a',i);SAM::radix_sort();for(int i=2; i<=SAM::nd; i++) G[SAM::lk[i]].pb(i);
pre_dfs(1);dfs(1,1);for(int i = 1; i <= n; i++)ins(ps[i],i);work(1); cout<<Ans; return 0;}
T3:
求等腰直角三角形的个数
枚举两个点搞出第三个点,考场用 map 跑了 5 点几秒吓死,改成 set 快了很多
#include<bits/stdc++.h>
#define cs const
#define pb push_back
using namespace std;
// Reads the next integer from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes the result negative.
// `ch` is an int so that isdigit() only ever sees EOF or an unsigned-char value (a plain char
// made isdigit undefined for bytes >= 0x80). Hitting EOF before any digit now returns 0 instead
// of spinning forever in the skip loop.
int read(){
    int cnt = 0, f = 1, ch = getchar();
    while(ch != EOF && !isdigit(ch)){
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch != EOF && isdigit(ch)) cnt = cnt * 10 + (ch - '0'), ch = getchar();
    return cnt * f;
}
const int N = 3e3 + 50;       // array capacity; indices 1..n are used
using pi = pair<int, int>;    // not referenced by the code shown
int n;
int x[N], y[N];               // coordinates of point i (1-based)
// A distinct point (x, y) with its multiplicity v. v is mutable so it can be bumped through a
// std::set iterator; it does not take part in ordering.
struct node{
    int x, y;
    mutable int v;
    node(int _x = 0, int _y = 0, int _v = 0) : x(_x), y(_y), v(_v) {}
    // Lexicographic order on (x, y).
    bool operator < (const node &a) const{
        if(x != a.x) return x < a.x;
        return y < a.y;
    }
};
set<node> S;
using It = set<node>::iterator;
int main(){freopen("triangle.in","r",stdin);freopen("triangle.out","w",stdout);
n = read();for(int i=1; i<=n; i++){
x[i]=read(), y[i]=read(); It it=S.find(node(x[i],y[i]));if(it==S.end()) S.insert(node(x[i],y[i],1));
else ++it->v;}
int as = 0;for(int i=1; i<=n; i++)for(int j=i+1; j<=n; j++){
int X=x[i]+x[j]+y[i]-y[j];
int Y=y[i]+y[j]+x[j]-x[i];
if(!(X&1)&&!(Y&1)){
X>>=1, Y>>=1;
It it=S.find(node(X,Y));if(it!=S.end()) as+=it->v;}
X=x[i]+x[j]+y[j]-y[i];
Y=y[i]+y[j]+x[i]-x[j];
if(!(X&1)&&!(Y&1)){
X>>=1, Y>>=1;
It it=S.find(node(X,Y));if(it!=S.end()) as+=it->v;}} cout<<as;}