题目
Input
一行两个正整数,分别表示n和m。
Output
一行一个正整数表示答案。
Sample Input
样例1:
2 2
样例2:
10 3
Sample Output
样例1:
129140165
样例2:
7008635
Data Constraint
对于30%的数据n*m<=16
对于另外20%的数据m<=4
对于100%的数据n<=50,m<=8
思路
考虑知道点权怎么求
建出trie,对于trie的每个节点,如果它既有0儿子,又有1儿子,那么这两棵子树分别联通后,要找一条最小的边把它们连起来。
于是我们可以枚举这个节点的深度,再枚举它的左子树和右子树的大小,问题转换为:
有x和y个k位二进制数,求它们之间的异或最小值。
设f[x][y][z][u]为有x和y个k位二进制数,最小值≥u的方案数。
代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int mod = 258280327;
ll power(ll x,ll y) {
ll s = 1;
for(; y; y /= 2,x = x * x % mod)
if(y & 1) s = s * x % mod;
return s;
}
const int N = 55;
const int M = 260;
int n,m;
ll c[M][M];
ll f[N][N][9][M];
int a2[9];
ll b[N];
void build(int n) {
for(int i=0; i<=n; i++) {
c[i][0] = 1;
for(int j=1; j<=i; j++) c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % mod;
}
}
int main()
{
build(256);
a2[0]=1;
for(int i=1; i<=8; i++) a2[i]=a2[i-1]*2;
scanf("%d %d",&n,&m);
for(int i=0; i<=n; i++) for(int j=0; j<=n; j++) for(int k=0; k<=m; k++) if(!i||!j||!k)
{
ll s = 1;
for(int u=1; u<=i+j; u++) s=s*a2[k]%mod;
for(int u=0; u<a2[k]; u++) f[i][j][k][u]=s;
}
for(int i=1; i<=n; i++) for(int j=1; j<=n-i; j++) for(int k=1; k<=m; k++)
{
for(int I=0; I<=i; I++) for(int J=0; J<=j; J++)
{
int t=a2[k-1];
if((I&&J)||((i-I)&&(j-J))) t=0;
ll xs=c[i][I]*c[j][J]%mod;
for(int u=0; u<a2[k-1]; u++)
{
if(t==0) f[i][j][k][u + t] = (f[i][j][k][u + t] + f[I][J][k - 1][u] * f[i - I][j - J][k - 1][u] % mod * xs) % mod;
else f[i][j][k][u + t] = (f[i][j][k][u + t] + f[I][j - J][k - 1][u] * f[i - I][J][k - 1][u] % mod * xs) % mod;
}
}
ll s = f[i][j][k][a2[k - 1]];
for(int u=0; u<a2[k-1]; u++) f[i][j][k][u]=(f[i][j][k][u]+s)%mod;
}
ll ans = 0;
for(int i=1; i<=m; i++)
{
b[0] = 1;
b[1] = a2[m] - a2[i];
for(int j=2; j<=n; j++) b[j] = b[j - 1] * b[1] % mod;
for(int j=1; j<=n; j++) for(int k=1; k<=n-j; k++)
{
ll xs = c[n][j] * c[n - j][k] % mod * a2[m - i] % mod * b[n - j - k] % mod;
ll xs2 = 1;
for(int u=1; u<=j+k; u++) xs2 = xs2 * a2[i - 1] % mod;
ans = (ans + xs * a2[i - 1] % mod * xs2) % mod;
ll s = 0;
for(int u=1; u<a2[i-1]; u++) s = (s + f[j][k][i - 1][u]) % mod;
ans = (ans + xs * s) % mod;
}
}
ll v = power(a2[m],n);
printf("%lld\n",ans * power(v,mod - 2) % mod);
}