題意:給定一個長度爲n的數組(數組元素∈{1,2,...,n}),對k∈{1,2,…,n},求最小的ans[k],使得數組可以分爲連續的ans[k]段,每段內不相同的元素個數都不超過k。
分析:爲了使ans[k]最小,我們可以貪心地來分段,即對起點l,取最大的r,使得a[l],...,a[r]中不相同的元素個數不超過k。可以來估計一下總段數(即ans[1]+...+ans[n])的上界。顯然有ans[k]≤n/k,所以ans[1]+...+ans[n]≤n(1/1+1/2+...+1/n)≈nlgn。因此只要能找到一種對固定的起點l,快速(如O(lgn))求出終點r的算法,我們就能較快地(如O(n(lgn)^2))解決這個問題。
基於主席樹的做法是比較容易想到的,所以這裏不再贅述。接下來詳細講解另一種和主席樹時空複雜度相同但代碼量小且常數小的做法。
假設對於k,數組分成的ans[k]段爲[l_k1,r_k1]、[l_k2,r_k2]、...、[l_kans[k],r_kans[k]]。那麼我們的程序執行過程中就需要求對起點l_kj的最大終點r_kj使得[l_kj,r_kj]是不相同元素個數不超過k的以l_kj爲起點的最大區間。考慮換一種順序來求解這一系列問題(一系列問題指一系列求終點的問題)。先解決以1爲起點的所有問題(此時要考慮的k取遍1到n),再解決以2爲起點的所有問題(此時要考慮的k不一定取遍1到n),依次類推。改變問題的處理順序後,我們就可以簡單地用一個樹狀數組來求解這些問題。具體做法是:在解決了以1~i-1爲起點的問題後,用一個數組array[]來標記a[i]、...、a[n]中每個數第一次出現的位置(如a[i]第一次出現在位置i,則令array[i]=1),則array[]的前r項和即爲a[l],...,a[r]中不相同的元素個數(注意array[1],...,array[i-1]均爲0)。用樹狀數組維護array[],那麼就可以通過倍增的方式O(lgn)地求出起點爲i,不相同元素不超過k的最大終點r(而不是二分再用樹狀數組求前綴和)。容易想到以O(lgn)的時間代價將以i爲起點的樹狀數組轉化爲以i+1爲起點的樹狀數組的做法。顯然這個做法的時空複雜度也是O(n(lgn)^2)。
代碼(主席樹)
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
struct tnode
{
int lc,rc,s;
}rt[maxn],nd[maxn*40],temp;
int n,a[maxn],sz,Last[maxn];
void build(tnode &o,int l,int r)
{
o.s=0;
if (l==r) return ;
int mid=(l+r)/2;
o.lc=++sz;o.rc=++sz;
build(nd[o.lc],l,mid);build(nd[o.rc],mid+1,r);
}
void updata(tnode &o1,tnode o2,int l,int r,int p,int val)
{
o1.s=o2.s+val;
if (l==r) return ;
int mid=(l+r)/2;
if (p<=mid)
{
o1.rc=o2.rc;
o1.lc=++sz;
updata(nd[o1.lc],nd[o2.lc],l,mid,p,val);
}
else
{
o1.lc=o2.lc;
o1.rc=++sz;
updata(nd[o1.rc],nd[o2.rc],mid+1,r,p,val);
}
}
int query(tnode o,int l,int r,int k)
{
if (o.s<=k) return r;
if (l==r) return l-1;
int mid=(l+r)/2;
if (nd[o.lc].s<=k) return query(nd[o.rc],mid+1,r,k-nd[o.lc].s);
else return query(nd[o.lc],l,mid,k);
}
int main()
{
cin>>n;
for (int i=1;i<=n;i++) scanf("%d",&a[i]);
build(rt[n+1],1,n);
for (int i=n;i>=1;i--)
{
if (Last[a[i]])
{
updata(temp,rt[i+1],1,n,Last[a[i]],-1);
updata(rt[i],temp,1,n,i,1);
}
else updata(rt[i],rt[i+1],1,n,i,1);
Last[a[i]]=i;
}
for (int k=1;k<=n;k++)
{
int ans=0;
for (int st=1;st<=n;)
{
//cout<<st<<" ";system("pause");
ans++;
st=query(rt[st],1,n,k)+1;
}
printf("%d ",ans);
}
return 0;
}
代碼(另解)
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
int n,a[maxn],c[maxn],ans[maxn];
set<int> pos[maxn];
vector<int> S[maxn];
int lowbit(int x)
{
return x&(-x);
}
void add(int x,int val)
{
while (x<=n)
{
c[x]+=val;
x+=lowbit(x);
}
}
int query(int k)
{
int ret=0;
for (int l=20;l>=0;l--)
if (ret+(1<<l)<=n&&c[ret+(1<<l)]<=k)
{
ret+=(1<<l);k-=c[ret];
}
return ret;
}
void work(int x)
{
if (!pos[x].empty())
{
add(*pos[x].begin(),1);
pos[x].erase(pos[x].begin());
}
}
int main()
{
cin>>n;
for (int i=1;i<=n;i++) scanf("%d",&a[i]),pos[a[i]].insert(i);
for (int i=1;i<=n;i++)
{
S[1].push_back(i);
work(i);
}
for (int i=1;i<=n;i++)
{
for (int j=0;j<S[i].size();j++)
{
int x=S[i][j];
int y=query(x)+1;
S[y].push_back(x);
ans[x]++;
}
add(i,-1);
work(a[i]);
}
for (int i=1;i<=n;i++) printf("%d ",ans[i]);
return 0;
}