題意
給n個單詞,用它們構成新單詞,新單詞串合法的條件是:
1. 與某個原單詞相同
2. 是某個原單詞前綴+某個單詞後綴(前綴和後綴非空、前綴和後綴可以是整個單詞、兩個拼接字符串可以取自一個單詞)
統計不同的新單詞數量。
單詞長度
題解
先不考慮重複,
那麼每個前綴能對應所有的後綴,總單詞數目是
不同非空前綴數×不同非空後綴數+原單詞總數
然後思考重複,
對每個前綴,當其向後添加一個字符時,以這個字符爲首的後綴都算重了。
所以直觀的想法是:
對每個單詞建前綴樹,統計以x爲兒子節點的前綴數量。
對每個單詞建後綴樹,統計以x爲兒子節點的後綴數量。
兩者相乘就是重複數。
這個想法寫起來稍煩。
稍微思考一下可以發現,後綴樹是沒必要的,因爲每個“串的後綴的翻轉”可以看成“翻轉串的前綴”,所以對翻轉串建前綴樹,此時x的兒子變爲x的父親,即
對每個單詞建前綴樹,統計以x爲兒子節點的前綴數量。
對每個單詞翻轉建前綴樹,統計以x爲父親節點的前綴數量。
兩者相乘就是重複數。
(前綴樹可以想象成根到每個葉子是一個前綴,而“翻轉串”前綴樹可以想象成每個葉子到根是一個後綴)
在寫法上,注意到可以在每次添加字符串時候統計兩個“前綴數量”,觀察得到這兩者是統一的。。(即是說,直接在添加新節點的時候,sum[x]++即可)(注意綴長度爲1的情況下不要加(相當於從空串增加1字符))
然後到了這題比較坑的地方,不妨考慮
1
abc
這組數據
第一部分統計的是9個+1個新單詞(其中abc分別在a.bc,ab.c和abc三個地方重複出現)
第二部分統計的是1個(ab.c)
a.bc不會被統計的原因是:a.bc是從.abc來的,這不會被統計到。
而好在這種情況只會發生在單詞第一個字母上,只要把第一部分統計的新單詞數去掉即可。
然而,這仍然是錯的Orz,原因在於,當單詞長度爲1時,.a與a.都不會被統計到。。故答案會少所有長度爲1的單詞。。
故最後的算法是
A=
不同非空前綴數×不同非空後綴數+不同的單個字母單詞數
對每個單詞建前綴樹,統計以x爲兒子節點的前綴數量。
對每個單詞翻轉建前綴樹,統計以x爲父親節點的前綴數量。
兩者相乘就是重複數B。
A-B即是答案。
code
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
typedef long long LL;
const int maxsigma=26;
const int maxnode=500010;
const int maxs=100;
int idx(char x) { return x - 'a'; }
struct TrieNode{
TrieNode *ch[maxsigma], *pre, *lst;
int v;
TrieNode(){ memset(ch,0,sizeof ch); lst=0; v=0; }
int calc(){ int x=(lst?lst->calc():0)+v; v=0; return x; }
};
struct Trie {
TrieNode trie[maxnode], *rot, *trieR;
int sum[maxsigma];
Trie(){
trieR=trie; rot=new(trieR++)TrieNode();
memset(sum,0,sizeof sum);
}
int size(){ return trieR-trie; }
void insert(char* s){
int n=strlen(s);
TrieNode* p=rot;
for(int i=0;i<n;++i){
int x=idx(s[i]);
if(!p->ch[x]){
p->ch[x]=new(trieR++)TrieNode();
if(i) ++sum[x];
}
p=p->ch[x];
}
}
};
int n;
bool p[26];
char s[maxs];
Trie pref,suf;
bool solve(){
if(!(scanf("%d",&n)==1))return 0;
new(&pref)Trie();
new(&suf)Trie();
memset(p,0,sizeof p);
LL res=0;
for(int i=0;i<n;++i){
scanf("%s",s); int m=strlen(s);
if(m==1&&!p[idx(s[0])]){
p[idx(s[0])]=1;
++res;
}
pref.insert(s);
reverse(s,s+m);
suf.insert(s);
}
// cout<<pref.siz e()<<' '<<suf.size()<<endl;
res+=(LL)(pref.size()-1)*(suf.size()-1);
for(int i=0;i<26;++i)res-=(LL)pref.sum[i]*suf.sum[i];
printf("%lld\n",res);
return 1;
}
int main(){
// freopen("D.in","r",stdin);
while(solve());
// for(;;);
return 0;
}