【bzoj4006】[JLOI2015]管道連接 斯坦納樹+子集dp

f[i][S]表示使S集合中的點聯通的最小值
g[S]表示把顏色爲集合S的點聯通的最小值
g[S]=min{g[s]+g[S-S]}
g[S]=f[i][SS] SS中的點均爲S集合中的顏色

跑的死慢,感覺正解應該比我的簡單吧


#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#define maxn 1010
#define inf 1000000000

 using namespace std;
 
 struct yts
 {
 	int x,id;
 }a[20];
 
 int f[maxn][1024],g[1024];
 int q[maxn];
 bool vis[maxn],flag[maxn];
 int bin[20];
 int head[maxn],to[2*3010],next[2*3010],len[2*3010];
 int n,m,num,k,cnt,CNT;
 vector<int> v[11];
 
 void addedge(int x,int y,int z)
 {
 	num++;to[num]=y;len[num]=z;next[num]=head[x];head[x]=num;
 }
 
 bool cmp(yts x,yts y) {return x.x<y.x;}
 
 void solve(int SS,int num)
 {
 	for (int i=1;i<=n;i++) for (int S=0;S<bin[num];S++) f[i][S]=inf;
 	num=0;
 	for (int i=1;i<=n;i++) if (flag[i]) f[i][bin[num++]]=0;
 	for (int S=1;S<bin[num];S++)
 	{
 		for (int i=1;i<=n;i++)
 		  for (int x=S&(S-1);x;x=(x-1)&S)
 		    f[i][S]=min(f[i][S],f[i][x]+f[i][S^x]);
 		memset(vis,0,sizeof(vis));
 		int l=0,r=0;
 		for (int i=1;i<=n;i++) if (f[i][S]!=inf) q[++r]=i,vis[i]=1;
 		while (l!=r)
 		{
 			l++;if (l==maxn) l=0;
 			int x=q[l];
 			for (int p=head[x];p;p=next[p])
 			  if (f[to[p]][S]>f[x][S]+len[p])
 			  {
 			  	f[to[p]][S]=f[x][S]+len[p];
 			  	if (!vis[to[p]])
 			  	{
 			  		r++;if (r==maxn) r=0;
 			  		q[r]=to[p];vis[to[p]]=1;
 			  	}
 			  }
 			vis[x]=0;
 		}
 	}
 	g[SS]=inf;
 	for (int i=1;i<=n;i++) g[SS]=min(g[SS],f[i][bin[num]-1]);
 }
 
 int main()
 {
 	scanf("%d%d%d",&n,&m,&k);
 	for (int i=1;i<=m;i++)
 	{
 		int x,y,z;
 		scanf("%d%d%d",&x,&y,&z);
 		addedge(x,y,z);addedge(y,x,z);
 	}
 	for (int i=1;i<=k;i++) scanf("%d%d",&a[i].x,&a[i].id);
 	sort(a+1,a+k+1,cmp);
 	for (int i=1;i<=k;i++)
 	{
 		int j=i;
 		while (j<k && a[j+1].x==a[i].x) j++;
 		if (i!=j) {cnt++;for (int p=i;p<=j;p++) v[cnt].push_back(a[p].id);}
 		i=j;
 	}
 	bin[0]=1;for (int i=1;i<=10;i++) bin[i]=bin[i-1]<<1;
 	for (int S=1;S<bin[cnt];S++)
 	{
 		int now=0;
 		memset(flag,0,sizeof(flag));
 		for (int i=1;i<=cnt;i++)
 		  if (S&(1<<(i-1)))
 		    for (int j=0;j<v[i].size();j++) flag[v[i][j]]=1,now++;
 		solve(S,now);
 		for (int x=S&(S-1);x;x=(x-1)&S) g[S]=min(g[S],g[x]+g[S-x]);
 	}
 	printf("%d\n",g[bin[cnt]-1]);
 	return 0;
 }


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章