永琳的竹林迷徑(path)
題目描述
竹林可以看作是一個n 個點的樹,每個邊有一個邊長wi,其中有k 個關鍵點,永琳需要破壞這些關鍵點才能走出竹林迷徑。
然而永琳打算將這k 個點編號記錄下來,然後隨機排列,按這個隨機的順序走過k 個點,但是兩點之間她只走最短路線。初始時永琳會施展一次魔法,將自己傳送到選定的k 個點中隨機後的第一個點。
現在永琳想知道,她走過路程的期望是多少,答案對998244353 取模。
注意,如果對期望不理解,題目最後有詳細解釋,請自行閱讀。
輸入
第一行一個數Case,表示測試點編號。(樣例的編號表示其滿足第Case 個測試點的性質)
下一行一個n,表示樹的點數。
下面 n-1 行,每行三個數ui,vi,wi,表示一條邊連接ui和vi,長度爲wi。
下面一行一個數k,表示關鍵點數。
下面一行k 個數,表示k 個關鍵點的編號。
輸出
一行一個數,表示答案(對998244353 取模)。
數據範圍
對於 100%的數據,保證1≤wi≤1041≤wi≤104。
測試點編號 |
n |
k |
特殊性質 |
1 |
≤10≤10 |
=1=1 |
無 |
2 |
|||
3 |
≤5≤5 |
||
4 |
|||
5 |
≤1000≤1000 |
≤7≤7 |
|
6 |
|||
7 |
≤105≤105 |
≤8≤8 |
|
8 |
|||
9 |
|||
10 |
≤16≤16 |
||
11 |
|||
12 |
|||
13 |
≤105≤105 |
||
14 |
|||
15 |
|||
16 |
|||
17 |
|||
18 |
≤106≤106 |
≤106≤106 |
是一條鏈 |
19 |
|||
20 |
|||
21 |
無 |
||
22 |
|||
23 |
|||
24 |
|||
25 |
【可能會用到的知識】
關於期望:
期望的定義:離散隨機變量的一切可能值與其對應的概率P 的乘積之和稱爲數學期望。
即: E(x)=∑P(x=k)×val(k)E(x)=∑P(x=k)×val(k)
其中E(x)是期望,P(x=k)是 x=k 發生的概率。
提示:答案必定可以表示成pqpq的形式,在模意義下,pq=p×q−1pq=p×q−1,其中q−1q−1是qq的逆元。
【提示】
讀入數據較大,請使用快速的讀入方式。
solution
本題求有序走完一個排列的期望長度。
考慮一個點i之前是j的貢獻:dist[i]+dist[j]-2*dist[lca]
算出所有點的dist的貢獻,lca的貢獻則做一次類似樹形dp的東西
一個點不同子樹互相走的數量就是它作lca的貢獻
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define maxn 1000006
#define mod 998244353
#define ll long long
using namespace std;
int n,num,head[maxn],s[maxn],flag[maxn],t1,t2,t3,tot;
int Te,size[maxn];
ll ans,Sum,ny2,ny,d[maxn];
struct node{
int v,nex,w;
}e[maxn*2];
int read(){
int v=0,ch;
while(!isdigit(ch=getchar()));v=ch-48;
while(isdigit(ch=getchar()))v=(v<<1)+(v<<3)+ch-48;
return v;
}
void lj(int t1,int t2,int t3){
e[++tot].v=t2;e[tot].w=t3;e[tot].nex=head[t1];head[t1]=tot;
}
void dfs(int k,int fa,ll dist){
d[k]=dist;
for(int i=head[k];i;i=e[i].nex){
if(e[i].v==fa)continue;
dfs(e[i].v,k,dist+e[i].w);
size[k]+=size[e[i].v];
}
size[k]+=flag[k];
}
void dp(int k,int fa){
ll sum=0;
for(int i=head[k];i;i=e[i].nex){
if(e[i].v==fa)continue;
dp(e[i].v,k);
sum+=size[e[i].v];
}
ll cnt=0;
for(int i=head[k];i;i=e[i].nex){
if(e[i].v==fa)continue;
cnt=cnt+size[e[i].v]*(sum-size[e[i].v])%mod;
cnt=cnt%mod;
}
ans=(ans-2*d[k]%mod*cnt%mod)%mod;
if(flag[k]){
ans=(ans-4*d[k]%mod*sum%mod)%mod;
}
}
ll work(ll a,int Num){
ll Ans=1,p=a;
while(Num){
if(Num&1)Ans=Ans*p;
p=p*p;p%=mod;Ans%=mod;Num>>=1;
}
return Ans;
}
int main()
{
Te=read();n=read();
for(int i=1;i<n;i++){
t1=read();t2=read();t3=read();
lj(t1,t2,t3);lj(t2,t1,t3);
}
num=read();
for(int i=1;i<=num;i++){
s[i]=read();flag[s[i]]++;
}
dfs(1,0,(ll)0);
ny2=work(2,mod-2);
for(int i=1;i<=num;i++)Sum+=d[s[i]];
for(int i=1;i<=num;i++){
ans=(ans+d[s[i]]*(num-1))%mod+(Sum-d[s[i]])%mod;
ans=ans%mod;
}
dp(1,0);ans%=mod;
ans=ans*work(num,mod-2);
ans=(ans%mod+mod)%mod;
printf("%lld\n",ans);
return 0;
}