先给个链接:Problem - E - Codeforces。
这题应该可以用哈希,但是容易被hack,我交了好几次都没过,就是在20多那块不好弄。
先来讲一下思路。显然,我们需要把这个比较抽象的问题具体化,C(a,b)的长度,其实就是a+b(拼接)的长度减去翻转后的a与b的最长相同前缀的二倍,具体推导过程恕不展示。
那么考虑实现,我们先读入每一个字符串并存储,之后循环遍历每个字符串,用某种数据结构或算法来记录每个前缀在所有字符串里出现的次数,然后计算出以该字符串作为‘a’的所有C(a,b)之和,但是这里我们不考虑删除结尾和开头的节点,因为暂时无法计算(其实只需要用该字符串长度乘上n*2即可),最后在遍历一次每个字符串,并把它翻转,依次遍历这个字符串的每个前缀,在答案的基础上减去该前缀出现的次数的二倍(a删掉一些,b也删掉一些)。首先我们想到的使用一个unordered_map<string,int>来记录每个字符串作为前缀出现了多少次,进一步想,如果用string超时,那就用哈希呗,但是过得很艰难……(这数据真的是[文字][文字],我[文字][文字]服了)
所以就要考虑用字典树,本质上字典树就是通过树形结构维护前缀来存储所有字符串,我们只要在每个节点上记录该前缀出现个数,再按照上述方法直接去做就可以了。
好了,我知道你们在等代码,大概率已经把上面的内容全都划走了,那么我直接奉上!
#include<bits/stdc++.h>
#define N 1100000
#define S 1100000
using namespace std;
struct TrieNode{
map<char,unsigned long long>son;
unsigned long long sum;
};
TrieNode trie[S]={};
string s[N]={};
unsigned long long ans=0,cnt=0,n=0;
int main(){
cnt++;
cin>>n;
for(unsigned long long i=1;i<=n;i++){
cin>>s[i];
}
for(unsigned long long i=1;i<=n;i++){
unsigned long long now=1;
for(char j:s[i]){
if(trie[now].son[j]!=0){
trie[trie[now].son[j]].sum++;
}else{
cnt++;
trie[now].son[j]=cnt;
trie[cnt].sum++;
}
now=trie[now].son[j];
}
ans+=s[i].size()*n;
}
for(unsigned long long i=1;i<=n;i++){
reverse(s[i].begin(),s[i].end());
unsigned long long now=1;
for(char j:s[i]){
if(trie[now].son[j]!=0){
ans-=trie[trie[now].son[j]].sum;
now=trie[now].son[j];
}else{
break;
}
}
}
cout<<ans*2<<endl;
return 0;
}
拒绝抄袭!(复制粘贴也算抄袭)
顺便说一嘴,这里我们发现前面加的和后面减得都是乘以二,所以可以先不乘上二,最后一起乘就行了。
如果不会字典树,我再粘一个哈希的AC代码。
#include<bits/stdc++.h>
#define N 1100000
using namespace std;
const long long p=1e9+7,x=30;
unordered_map<unsigned long long,unordered_map<unsigned long long,int>>ump={};
string s[N]={};
long long ans=0;
unsigned long long n=0;
int main(){
cin>>n;
for(unsigned long long i=1;i<=n;i++){
cin>>s[i];
}
for(unsigned long long i=1;i<=n;i++){
unsigned long long h1=0,h2=0;
for(char j:s[i]){
h1=h1*x+(j-'a'+1);
h2=h2*x+(j-'a'+1);
h1%=p;
ump[h1][h2]++;
}
ans+=s[i].size()*n*2;
}
for(unsigned long long i=1;i<=n;i++){
reverse(s[i].begin(),s[i].end());
unsigned long long h1=0,h2=0;
for(char j:s[i]){
h1=h1*x+(j-'a'+1);
h2=h2*x+(j-'a'+1);
h1%=p;
ans-=ump[h1][h2]*2;
}
}
cout<<ans;
return 0;
}