題面
給定一張有向無環圖,定義重要節點x爲:對於圖中每一個點y,都可以從y出發到x或者從x出發到y.
次重要節點爲刪除某個節點後能滿足上面條件的節點(不包括重要節點),求重要節點和次重要節點一共有多少個
\(2<=n, m<= 300000\)
題解
我們考慮這樣兩個性質:
- 當我們在拓撲排序的時候,同一時刻出現在隊列中的點互不可達。
- 除去已經排序完的點和當前在隊列的點,剩下的點全都是直接or間接由當前隊列中的點拓展出來的。
我們記f[i]爲點i可以到達or可以到達點i的點數之和。那麼題目要求的點就是滿足f[i] >= n - 2的點。
當隊列裏點數大於等於3的時候,由於互不可達,因此他們必然不是我們要求的點
當隊列裏點數只有1的時候,由性質2可以得出,剩下的點都可以加入f[x]
當隊列裏點數爲2時,假設點爲x和y,那麼如果後續有一個點只能被x到達,那麼它顯然不能對y產生貢獻。打個標記就行
一個細節:
如果x在y前面,那麼我們應該用y來判x,因爲當前在更新x的f值.而當x被更新完後,只能被x到達的點已經加入隊列了。
考慮如何用y來判x:
我們定義內部點爲y可達的所有點,內部邊爲兩個端點都是內部點的邊
如果y的某個鄰居入度不是1,那麼說明x有可能直接or間接指向它。
如果x不能指向它,說明這個鄰居是被內部點指向了,那麼我們沿着內部邊反向走,一定可以走到一個(也可能是多個)點,滿足這個點沒有內部點連向它,不然的話就出現了環,不符合題意。那麼對於這樣的點,如果x不能連向它,說明它只有一個入度且入度爲y(不然的話xy都不連向它,內部點也不連向它,它早入度爲0進隊了。),那麼y會根據這個點,給x打上標記。
#include<bits/stdc++.h>
using namespace std;
#define R register int
#define AC 301000
#define ac 601000
int n, m, ans;
int f[AC], in[AC];
int Head[AC], Next[ac], date[ac], tot;
int q[AC], head, tail;
bool z[AC];
struct node{
int f, w;
}way[AC];
inline int read()
{
int x = 0;char c = getchar();
while(c > '9' || c < '0') c = getchar();
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x;
}
inline void add(int f, int w){
date[++ tot] = w, Next[tot] = Head[f], Head[f] = tot, in[w] ++;
}
void pre()
{
n = read(), m = read();
for(R i = 1; i <= m; i ++)
{
way[i].f = read(), way[i].w = read();
add(way[i].f, way[i].w);
}
}
void get(int cnt)//tmp == 2時,前者若不合法,在get中被篩出
{
int x = q[head - 1], y = q[head];
bool flag = 0;
for(R i = Head[y]; i; i = Next[i])
if(in[date[i]] == 1) {flag = 1; break;}
if(flag) z[x] = 1;
else f[x] += n - cnt;
}
void t_sort()
{
int cnt = 0;
head = 1, tail = 0;
for(R i = 1; i <= n; i ++)
if(!in[i]) q[++ tail] = i, ++ cnt;
while(head <= tail)
{
int x = q[head ++];
int tmp = tail - head + 1 + 1;
if(tmp == 1) f[x] += n - cnt;
if(tmp == 2) get(cnt);
for(R i = Head[x]; i; i = Next[i])
{
int now = date[i];
if(!(-- in[now])) q[++ tail] = now, ++ cnt;//tmp == 2時,後者若不合法,在這裏保證他的f值正確
}
}
}
void work()
{
t_sort();
tot = 0, memset(Head, 0, (n + 2) * 4), memset(in, 0, (n + 2) * 4);
for(R i = 1; i <= m; i ++) add(way[i].w, way[i].f);
t_sort();
for(R i = 1; i <= n; i ++)
if(!z[i] && f[i] >= n - 2) ++ ans;
printf("%d\n", ans);
}
int main()
{
// freopen("in.in", "r", stdin);
pre();
work();
return 0;
}