題目描述:
試題編號: 201709-4
試題名稱: 通信網絡
時間限制: 1.0s
內存限制: 256.0MB
問題描述:
問題描述某國的軍隊由N個部門組成,爲了提高安全性,部門之間建立了M條通路,每條通路只能單向傳遞信息,即一條從部門a到部門b的通路只能由a向b傳遞信息。信息可以通過中轉的方式進行傳遞,即如果a能將信息傳遞到b,b又能將信息傳遞到c,則a能將信息傳遞到c。一條信息可能通過多次中轉最終到達目的地。>
由於保密工作做得很好,並不是所有部門之間都互相知道彼此的存在。只有當兩個部門之間可以直接或間接傳遞信息時,他們才彼此知道對方的存在。部門之間不會把自己知道哪些部門告訴其他部門。
上圖中給了一個4個部門的例子,圖中的單向邊表示通路。部門1可以將消息發送給所有部門,部門4可以接收所有部門的消息,所以部門1和部門4知道所有其他部門的存在。部門2和部門3之間沒有任何方式可以發送消息,所以部門2和部門3互相不知道彼此的存在。
現在請問,有多少個部門知道所有N個部門的存在。或者說,有多少個部門所知道的部門數量(包括自己)正好是N。
輸入格式
輸入的第一行包含兩個整數N, M,分別表示部門的數量和單向通路的數量。所有部門從1到N標號。
接下來M行,每行兩個整數a, b,表示部門a到部門b有一條單向通路。
輸出格式
輸出一行,包含一個整數,表示答案。
樣例輸入
4 4
1 2
1 3
2 4
3 4
樣例輸出
2
樣例說明
部門1和部門4知道所有其他部門的存在。
評測用例規模與約定
對於30%的評測用例,1 ≤ N ≤ 10,1 ≤ M ≤ 20;
對於60%的評測用例,1 ≤ N ≤ 100,1 ≤ M ≤ 1000;
對於100%的評測用例,1 ≤ N ≤ 1000,1 ≤ M ≤ 10000。
題解:
看到這道題我一開始的做法是直接雙向dfs:
比如說我們看樣例:
首先正向對每個點dfs一遍:
從1開始,搜到2,4,3
從2開始,搜到4
從3開始,搜到4
從4開始,啥都搜不到
然後反向對每個點dfs一遍:
從1開始,啥都都不到
從2開始,搜到1
從3開始,搜到1
從4開始,搜到2,1,3統計一下
點1可以與其他的4個點有聯繫 sum++
點2只能與4,1有聯繫
點3只能與4,1有聯繫
點4可以與其他四個點有聯繫 sum++
所以sum=2,ok。
我試了幾個樣例,好像都沒錯,但是最後的得分只有10分,時間是>1s,應該是超時了。
10分代碼:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <climits>
#include <cstring>
#include <string>
#include <algorithm>
#include <vector>
#include <deque>
#include <list>
#include <utility>
#include <set>
#include <map>
#include <stack>
#include <queue>
#include <bitset>
#include <iterator>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const double PI = acos(-1.0);
const double E = exp(1.0);
const int MOD = 1e9+7;
const int MAX = 1e3+5;
int n,m;
int g1[MAX][MAX];// 正向圖
int g2[MAX][MAX];// 反向圖
int flag[MAX];// 點的訪問標記
void dfs_g1(int a)
{
//printf("正在dfs正向圖,此時位於:%d\n",a);
for(int b = 1; b <= n; b++)
{
if(b == a) continue;
if(g1[a][b] && !flag[b])
{
flag[b] = 1;
dfs_g1(b);
//flag[b] = 0;
}
}
}
void dfs_g2(int a)
{
//printf("正在dfs反向圖,此時位於:%d\n",a);
for(int b = 1; b <= n; b++)
{
if(b == a) continue;
if(g2[a][b] && !flag[b])
{
flag[b] = 1;
dfs_g2(b);
//flag[b] = 0;
}
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
while(cin >> n >> m)
{
memset(g1,0,sizeof(g1));
memset(g2,0,sizeof(g1));
int a,b;
for(int i = 1; i <= m; i++)
{
cin >> a >> b;
g1[a][b] = 1;
g2[b][a] = 1;
}
int sum = 0;
for(int a = 1; a <= n; a++)
{
int ans = 0;
for(int i = 1; i <= n; i++)
{
flag[i] = 0;
}
flag[a] = 1;
dfs_g1(a);
/* 統計從點a正向dfs可以遍歷到幾個點
ans = 0;
for(int i = 1; i <= n; i++)
{
if(flag[i])
{
ans++;
}
}
print("%d\n",ans);*/
dfs_g2(a);
ans = 0;
for(int i = 1; i <= n; i++)// 統計在從點a正向一次dfs的基礎上,又從點a反向一次dfs,可以經過多少個點
{
if(flag[i])
{
ans++;
}
}
//printf("反向dfs後,ans = %d\n",ans);
if(ans == n)
{
sum++;
}
}
cout << sum << endl;
}
return 0;
}
我還沒有考慮去怎麼優化,直接參考了別人的博客。
思路:
對n個點進行一次dfs,用一個二維數組記錄下 與起點 有聯繫的點(也就是dfs可以到的點),比如起點s和終點e有聯繫(正向dfs可以到達的),那麼終點e和起點s就也有聯繫(反向dfs可以到達的),直接標記connect[s][e] = connect[e][s] = 1就好了,所以就不用做兩遍dfs了,最後統計一下每個點相關聯的點如果有n個(包括自己),就sum++。
代碼:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <climits>
#include <cstring>
#include <string>
#include <algorithm>
#include <vector>
#include <deque>
#include <list>
#include <utility>
#include <set>
#include <map>
#include <stack>
#include <queue>
#include <bitset>
#include <iterator>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const double PI = acos(-1.0);
const double E = exp(1.0);
const int MOD = 1e9+7;
const int MAX = 1e3+5;
int n,m;
vector <int> g[MAX];
int connect[MAX][MAX];
int vis[MAX];
void dfs(int s,int cur)// s-起點 cur-當前點
{
vis[cur] = 1;
connect[s][cur] = connect[cur][s] = 1;// s點認識cur點,同時cur點也認識s點
for(int i = 0; i < (int)g[cur].size(); i++)// 遍歷當前點cur的相鄰點
{
if(!vis[g[cur][i]])
{
// 當前點cur沒有被訪問,可以從當前點cur的相鄰點g[cur][i]繼續往下訪問
dfs(s,g[cur][i]);
}
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
while(cin >> n >> m)
{
memset(connect,0,sizeof(connect));
int a,b;
for(int i = 1; i <= m; i++)
{
cin >> a >> b;
g[a].push_back(b);// 記錄下a相鄰的所有點
}
for(int i = 1; i <= n; i++)
{
memset(vis,0,sizeof(vis));
dfs(i,i);// dfs(起點,當前點)
}
int sum = 0;
for(int i = 1; i <= n; i++)
{
int tmp = 0;
for(int j = 1; j <= n; j++)
{
if(connect[i][j])
{
tmp++;
}
}
if(tmp == n)
{
sum++;
}
}
cout << sum << endl;
}
return 0;
}