You are given a permutation of {1,2,3,...,n}. Remove m of its elements one by one, and output the number of inversion pairs before each removal. The number of inversion pairs of an array A is the number of ordered pairs (i,j) such that i < j and A[i] > A[j].
Input
The input contains several test cases. The first line of each case contains two integers n and m (1<=n<=200,000, 1<=m<=100,000). After that, n lines follow, representing the initial permutation. Then m lines follow, representing the removed integers, in the order of the removals. No integer will be removed twice. The input is terminated by end-of-file (EOF). The size of input file does not exceed 5MB.
Output
For each removal, output the number of inversion pairs before it.
Sample Input
5 4
1
5
3
4
2
5
1
4
2
Output for the Sample Input
5
2
2
1
題意:給出一個1~n的排列A,要求按照某種順序刪除一些數(其他數順序不變),輸出每次刪除之前逆序對的數目。
思路:O(nlogn)求出初始逆序數。然後樹狀數組套靜態BST。
樹狀數組的每個元素爲一棵靜態BST。刪除一個元素時,只要計算出該元素所貢獻的逆序數,並從當前逆序數中減去,即可得到刪除後的逆序數。
也就是前面比它大的和後面比它小的。只需在樹狀數組的對應的BST上跑即可統計出來。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Capacity of the per-position arrays; slightly above the stated limit n <= 200,000.
#define maxn 200080
// Capacity of the BST node pool shared by all per-index trees.
#define maxm 18000000
#define LL long long int
// BST node pool, indexed by node id (id 0 is the "empty subtree" sentinel):
//   lson / rson - ids of the left / right child (0 when there is none)
//   key         - value stored at the node
//   vis         - 1 while the node's value is still present, 0 once it has been removed
//   Size        - number of still-present values in the node's subtree
int lson[maxm],rson[maxm],key[maxm],vis[maxm],Size[maxm];
// a    - the input permutation (1-based)
// b    - scratch buffer that Build() fills and sorts before building each BST
// ope  - the values to remove, in removal order
// root - root[i] is the root node id of the BST attached to index i
// Pos  - Pos[v] is the position of value v inside a
int a[maxn],b[maxn],ope[maxn],root[maxn],Pos[maxn];
// Id of the most recently allocated BST node; new nodes are taken with ++cnt.
int cnt;
// Resets the node allocator and the sentinel node 0 before a test case is built.
void init()
{
    // Node 0 is never handed out (ids start at 1 via ++cnt); it stands for
    // "no child", so its fields must all read as zero.
    vis[0] = 0;
    rson[0] = 0;
    lson[0] = 0;
    Size[0] = 0;
    cnt = 0;
}
// Returns the value of the lowest set bit of x (0 when x is 0).
int lowbit(int x)
{
    const int negated = -x;
    return negated & x;
}
// Recursively builds a BST over the sorted scratch values b[l..r], placing the
// midpoint value in node `pos`. Child nodes are taken from the global pool by
// bumping cnt; an empty child is represented by node id 0.
void build(int pos,int l,int r)
{
    if (r < l) return;

    const int mid = (l + r) >> 1;
    key[pos] = b[mid];
    vis[pos] = 1;  // every value starts out present

    // Left subtree covers b[l..mid-1].
    lson[pos] = 0;
    if (mid > l)
    {
        const int child = ++cnt;
        lson[pos] = child;
        build(child, l, mid - 1);
    }

    // Right subtree covers b[mid+1..r].
    rson[pos] = 0;
    if (mid < r)
    {
        const int child = ++cnt;
        rson[pos] = child;
        build(child, mid + 1, r);
    }

    // This node plus whatever both subtrees hold (Size[0] is 0 for empty children).
    Size[pos] = 1 + Size[lson[pos]] + Size[rson[pos]];
}
// Builds one static BST per binary-indexed-tree index i (1..n). The tree for
// index i holds the values a[i-lowbit(i)+1 .. i]; root[i] records its root node.
void Build(int n)
{
    for (int idx = 1; idx <= n; ++idx)
    {
        const int len = lowbit(idx);      // how many positions index idx covers
        const int first = idx - len + 1;  // leftmost covered position

        // Copy the covered slice into the 1-based scratch buffer and sort it.
        for (int k = 1; k <= len; ++k)
            b[k] = a[first + k - 1];
        sort(b + 1, b + 1 + len);

        // build() stores the midpoint of the sorted buffer as the root's key.
        const int rootNode = ++cnt;
        root[idx] = rootNode;
        build(rootNode, 1, len);
    }
}
// Segment-tree node: covers the closed interval [l, r]; `sum` is the number of
// positions inside that interval that have been marked by update().
struct ST
{
    int l;
    int r;
    int sum;
};
// Node storage, indexed from 1; the children of node id are id*2 and id*2+1.
ST st[maxn << 2];
void buildtree(int id,int l,int r)
{
st[id].l = l,st[id].r = r;
st[id].sum = 0;
if(l == r) return;
int mid = (l+r) >> 1;
buildtree(id<<1,l,mid);
buildtree(id<<1|1,mid+1,r);
}
// Recomputes the count stored at node `id` as the sum of its two children.
void PushUp(int id)
{
    const int left = id << 1;
    const int right = left | 1;
    st[id].sum = st[left].sum + st[right].sum;
}
// Marks position `pos` (sets its leaf count to 1) inside the subtree of node
// `id`, then refreshes the counts of the ancestors on the way back up.
void update(int id,int pos)
{
    const bool isTargetLeaf = (st[id].l == pos) && (st[id].r == pos);
    if (isTargetLeaf)
    {
        st[id].sum = 1;
        return;
    }

    // Descend into whichever child interval contains pos.
    const int left = id << 1;
    const int next = (pos <= st[left].r) ? left : (left | 1);
    update(next, pos);
    PushUp(id);
}
// Returns the number of marked positions in [l, r] within the subtree of node
// `id`. The recursion keeps [l, r] inside the interval of the node it visits.
int query(int id,int l,int r)
{
    if (st[id].l == l && st[id].r == r)
        return st[id].sum;

    const int left = id << 1;
    const int right = left | 1;
    if (r <= st[left].r)    // range lies entirely in the left child
        return query(left, l, r);
    if (l >= st[right].l)   // range lies entirely in the right child
        return query(right, l, r);

    // Range straddles both children: split it at the child boundary.
    const int leftPart = query(left, l, st[left].r);
    const int rightPart = query(right, st[right].l, r);
    return leftPart + rightPart;
}
// Marks value k as removed in the BST rooted at `pos`: each node visited on the
// search path has its subtree count decremented, and the node whose key equals
// k gets its presence flag cleared. The tree shape itself never changes.
void remove(int pos,int k)
{
    int cur = pos;
    while (cur != 0)
    {
        --Size[cur];
        if (key[cur] == k)
        {
            vis[cur] = 0;
            break;
        }
        cur = (k < key[cur]) ? lson[cur] : rson[cur];
    }
}
void gao(int pos,int n)
{
for(int i = pos;i <= n;i += lowbit(i))
{
int rot = root[i];//樹的根
remove(rot,a[pos]);
}
}
// Counts the still-present values greater than k in the BST rooted at `pos`
// (used to count the elements to the left that are larger than the query value).
int LeftMore(int pos,int k)
{
    int total = 0;
    for (int cur = pos; cur != 0; )
    {
        const int nodeKey = key[cur];
        if (nodeKey == k)
        {
            // Found k: the remaining larger values are this node's right subtree.
            total += Size[rson[cur]];
            break;
        }
        if (nodeKey > k)
        {
            // This node (when present) and its whole right subtree exceed k.
            total += vis[cur] + Size[rson[cur]];
            cur = lson[cur];
        }
        else
        {
            cur = rson[cur];
        }
    }
    return total;
}
// Counts the still-present values smaller than k in the BST rooted at `pos`.
int RightLess(int pos,int k)
{
    int total = 0;
    for (int cur = pos; cur != 0; )
    {
        const int nodeKey = key[cur];
        if (nodeKey == k)
        {
            // Found k: the remaining smaller values are this node's left subtree.
            total += Size[lson[cur]];
            break;
        }
        if (nodeKey > k)
        {
            cur = lson[cur];
        }
        else
        {
            // This node (when present) and its whole left subtree are below k.
            total += Size[lson[cur]] + vis[cur];
            cur = rson[cur];
        }
    }
    return total;
}
int count(int pos,int n)
{
int sum = 0,lsum = 0;
for(int i = pos;i > 0;i -= lowbit(i))
{
int rot = root[i];
lsum += Size[rot];
sum += LeftMore(rot,a[pos]);
}
lsum -= sum;
for(int i = n;i > 0;i -= lowbit(i))
{
int rot = root[i];
sum += RightLess(rot,a[pos]);
}
sum -= lsum;
return sum;
}
// Processes test cases until EOF. For each case: reads the permutation and the
// removal list, builds the per-index BSTs, computes the initial inversion count,
// then prints the current count before each removal and subtracts the removed
// element's contribution.
int main()
{
    //freopen("in.txt","r",stdin);
    int n,m;
    while (scanf("%d%d",&n,&m) == 2)
    {
        init();

        // Read the permutation and remember the position of every value.
        for (int i = 1; i <= n; ++i)
        {
            scanf("%d",&a[i]);
            Pos[a[i]] = i;
        }
        // Read the values to remove, in removal order.
        for (int i = 1; i <= m; ++i)
            scanf("%d",&ope[i]);

        // The per-index BSTs are built here.
        Build(n);

        // Initial inversion count: for each element, count the already-seen
        // values in [a[i], n] (a[i] itself is not marked yet), then mark a[i].
        buildtree(1,1,n);
        LL ans = 0;
        for (int i = 1; i <= n; ++i)
        {
            const int value = a[i];
            ans += query(1,value,n);
            update(1,value);
        }

        // Handle the removals one by one.
        for (int i = 1; i <= m; ++i)
        {
            const int pos = Pos[ope[i]];  // position of the value being removed
            printf("%lld\n",ans);         // answer is the count before this removal
            gao(pos,n);                   // take the element out of every covering tree
            ans -= count(pos,n);          // drop the inversions it took part in
        }
    }
    return 0;
}