【題解】codeforces766E Mahmoud and a xor trip

題目鏈接

題意:給定一棵樹,樹的每個點有點權,定義2個點u和v之間的距離爲u到v的路徑上的點的點權的異或和。求全體點對(u,v):1<=u<=v<=n的距離和。

分析:考慮按位處理距離和。設有ans[i]個點對的距離的第i位爲1,則距離和=ans[0]*2^0+ans[1]*2^1+...+ans[20]*2^20。從而問題轉化爲點權爲0或1的情況。對於轉化後的問題,我是用樹分治處理的:對於子樹u的所有點對路徑,要麼經過點u,要麼不經過點u,經過點u的通過維護點權異或和爲0,1的鏈數來進行統計,不經過點u的遞歸處理。注意應該把點權轉化爲向量來處理,這樣只用跑一遍樹分治,跑20遍會tle。

        看完題解後發現不用樹分治,直接樹形dp就可以了。這是因爲每個點只要訪問一次就能得到我們需要的信息,所以不需要每層都把所以點都遍歷一遍。

代碼(樹分治)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long LL;
const int maxn=1e5+10,maxl=21;
int n,a[maxn],b[maxn][maxl];
vector<int> G[maxn];
int sz[maxn],root,sum,minmaxs;
LL s[maxl][2],t[maxl][2],ans[maxl];
bool done[maxn];
void dfs_sz(int u,int fu)
{
	sz[u]=1;
	for (int i=0;i<G[u].size();i++)
	{
		int v=G[u][i];
		if (v==fu||done[v]) continue;
		dfs_sz(v,u);sz[u]+=sz[v];
	}
}
void dfs_rt(int u,int fu)
{
	int maxs=sum-sz[u];
	for (int i=0;i<G[u].size();i++)
	{
		int v=G[u][i];
		if (v==fu||done[v]) continue;
		dfs_rt(v,u);maxs=max(maxs,sz[v]);
	}
	if (maxs<minmaxs) {minmaxs=maxs;root=u;}
}
void dfs(int u,int fu,int *o)
{
	for (int l=0;l<maxl;l++) t[l][o[l]]++;
	for (int i=0;i<G[u].size();i++)
	{
		int v=G[u][i];
		if (v==fu||done[v]) continue;
		int o1[maxl];
		for (int l=0;l<maxl;l++) o1[l]=o[l]^b[v][l];
		dfs(v,u,o1);
	}
}
void solve(int u)
{
	dfs_sz(u,-1);
	minmaxs=maxn;sum=sz[u];
	dfs_rt(u,-1);
	u=root;done[u]=1;
	
	//cout<<u<<endl;
	
	memset(s,0,sizeof(s));
	for (int i=0;i<G[u].size();i++)
	{
		int v=G[u][i];
		if (done[v]) continue;
		memset(t,0,sizeof(t));
		dfs(v,u,b[v]);
		for (int l=0;l<maxl;l++)
		{
		    if (b[u][l])
		    {
			    ans[l]+=t[l][0];
			    ans[l]+=t[l][0]*s[l][0]+t[l][1]*s[l][1];
		    }
		    else
		    {
			    ans[l]+=t[l][1];
			    ans[l]+=t[l][0]*s[l][1]+t[l][1]*s[l][0];
		    }
		    s[l][0]+=t[l][0];s[l][1]+=t[l][1];
		}
	}
	for (int i=0;i<G[u].size();i++)
	{
		int v=G[u][i];
		if (done[v]) continue;
		solve(v);
	}
}
int main()
{
	LL ret=0;
	cin>>n;
	for (int i=1;i<=n;i++) scanf("%d",&a[i]),ret+=a[i];
	for (int i=1;i<n;i++)
	{
		int u,v;
		scanf("%d%d",&u,&v);
		G[u].push_back(v);G[v].push_back(u);
	}
	for (int i=1;i<=n;i++)
	    for (int j=0;j<maxl;j++)
	        b[i][j]=(a[i]>>j)&1;
	solve(1);
	for (int l=0;l<maxl;l++) ret+=(1<<l)*ans[l];
	cout<<ret;
	return 0;
}

代碼(樹形dp)

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn=1e5+10,maxl=21;
int n,a[maxn],f[maxn][maxl][2];
vector<int> G[maxn];
LL ans[maxl];
void dp(int u,int fu)
{
	LL s[maxl][2];memset(s,0,sizeof(s));
	for (int i=0;i<G[u].size();i++)
	{
		int v=G[u][i];
		if (v==fu) continue;
		dp(v,u);
		for (int l=0;l<maxl;l++)
		{
		    if (a[u]&(1<<l))
		    	ans[l]+=f[v][l][0]+f[v][l][0]*s[l][0]+f[v][l][1]*s[l][1];
			else
			    ans[l]+=f[v][l][1]+f[v][l][0]*s[l][1]+f[v][l][1]*s[l][0];
			//if (l==10&&u==2) cout<<f[v][l][1]<<" "<<ans[l]<<endl;
			s[l][0]+=f[v][l][0];
			s[l][1]+=f[v][l][1];
	    }
	}
	for (int l=0;l<maxl;l++)
	{
	    int d=(a[u]>>l)&1;
		f[u][l][0]=s[l][0^d];f[u][l][1]=s[l][1^d];f[u][l][d]++;
		//if (l==10&&u==3) cout<<d<<endl;
	}
}
int main()
{
	LL ret=0;
	cin>>n;
	for (int i=1;i<=n;i++) scanf("%d",&a[i]),ret+=a[i];
	for (int i=1;i<n;i++)
	{
		int u,v;scanf("%d%d",&u,&v);
		G[u].push_back(v);G[v].push_back(u);
	}
	dp(1,-1);
	//for (int l=0;l<maxl;l++) cout<<ans[l]<<" ";
	for (int l=0;l<maxl;l++) ret+=ans[l]*(1<<l);
	cout<<ret;
	return 0;
}


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章