做《算法進階》時,我遇到了我從未涉及到的概率問題
研究了很久,終於學會概率dp和期望值
看一下這道題
【題意】
給定一個無向連通圖,其節點編號爲1到N,其邊的權值爲非負整數。
試求出一條從1號節點都N號節點的路徑,使得該路徑上經過的邊的權值的XOR和最大。
該路徑可以重複經過某些節點或邊,當一條邊在路徑中出現多次時,其權值在計算XOR和時也應被重複計算相應多的次數。
直接求解上述問題比較困難,於是你決定使用非完美算法。
具體來說,從1號節點開始,以相等的概率,隨機選擇與當前節點相關聯的某條邊,並沿着這條邊走到下一個節點,重複這個過程直到走到N號節點爲止,便得到一條從1號節點到N號節點的路徑。
顯然得到每條這樣的路徑的概率是不同的,並且每條這樣的路徑的XOR和也不一樣。
現在請你求出該算法得到的路徑的XOR和的期望值。
【輸入格式】
第一行包含兩個整數N和M,表示節點數和邊數。
接下來M行,每行包含三個整數u,v,w,表示存在一條邊(u,v),權值爲w。
圖中可能存在重邊或自環。
【輸出格式】
輸出包含一個實數,表示XOR和的期望值,結果保留三位小數。
【數據範圍】
2≤N≤100,
M≤10000,
1≤u,v≤N,
0≤w≤10^9
【輸入樣例】
2 2
1 1 2
1 2 3
【輸出樣例】
2.333
因爲不能直接將這個期望值當成一個整體來算,所以我們將它拆分。
期望值其實就是平均數,它=所有的(可能值X概率)
又因爲是異或運算,所以我們不妨將它拆分爲二進制的每一個位
先設f[i]爲i位是1的概率。
所以
解釋一下第一條公式:
先是枚舉每一條與u相連的邊(連到v)
如果這條邊二進制拆分以後所求的一位是1
就要加上v點0的概率,也就是1-f[v]
否則就是要加上v點1的概率,也就是f[v]
只有這樣,我們才能使u是1
參考代碼(有修改):
//Author:XuHt
#include <cmath>
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
using namespace std;
const int N = 106;
int n, m;
double a[N][N], b[N], ans;
vector<pair<int, int> > e[N];
void work() {
for (int i = 1; i < n; i++) {
/* int now = i;算法進階裏的標程有這一句,實際上可以不用
for (int j = i + 1; j < n; j++)
if (fabs(a[j][i]) > fabs(a[now][i])) now = j;
for (int j = 0; j <= n; j++) swap(a[i][j], a[now][j]);*/
for (int j = i + 1; j <= n; j++) {
double rate = a[j][i] / a[i][i];
for (int k = 0; k <= n; k++) a[j][k] = a[i][k] * rate - a[j][k];
}
}
for (int i = n; i; i--) {
for (int j = i + 1; j <= n; j++) a[i][0] -= a[i][j] * b[j];
b[i] = a[i][0] / a[i][i];
}
}
int main() {
cin >> n >> m;
for (int i = 1; i <= m; i++) {
int x, y, z;
scanf("%d %d %d", &x, &y, &z);
e[x].push_back(make_pair(y, z));
if (x != y) e[y].push_back(make_pair(x, z));
}
for (int i = 0; i < 31; i++) {
memset(a, 0, sizeof(a));
memset(b, 0, sizeof(b));
//高斯消元數組的構造(第三條公式除以dg[u])
for (int x = 1; x <= n; x++) a[x][x] = 1;
for (int x = 1; x < n; x++) {
int s = e[x].size();
for (int j = 0; j < s; j++) {
int y = e[x][j].first, z = e[x][j].second;
double w = 1.0 / s;
if ((z >> i) & 1) {
a[x][y] += w;
a[x][0] += w;
} else a[x][y] -= w;
}
}
work();
ans += b[1] * (1 << i);
}
printf("%.3f\n", ans);
return 0;
}