0%

回文自动机学习笔记

回文自动机是一种可以处理回文符问题的优雅高效的数据结构。

URAL 1960 Palindromes and Super Abilities

题意

求各前缀的所有子串中的回文串种类。

分析

依次插入字符,每插入完一个字符,假如last指针所指的节点是新增的,那么答案加一。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include<bits/stdc++.h>
#define x first
#define y second
#define ok cout << "ok" << endl;
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef vector<int> vi;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const long double PI = acos(-1.0);
const int INF = 0x3f3f3f3f;
const ll LINF = 9223372036854775807;
const double Eps = 1e-7;
const int N = 1e5+9;

/************************************************************/
/* 回文自动机:解决一类回文字符串问题(Tested 0 times)
* 时间复杂度:O(|S| * log(字符集个数))
*/
const int MAXN = 1e5 + 9;
const int NN = 26;

struct Palindromic_Tree {
int next[MAXN][NN] ;//next指针,next指针和字典树类似,指向的串为当前串两端加上同一个字符构成
int fail[MAXN] ;//fail指针,失配后跳转到fail指针指向的节点
int cnt[MAXN] ; //cnt[i]表示i表示的回文字符串在整个字符串中出现了多少次
int num[MAXN] ; //num[i]表示i表示的回文字符串中有多少个本质不同的字符串(包括本身)
int len[MAXN] ;//len[i]表示节点i表示的回文串的长度
int S[MAXN] ;//存放添加的字符
int last ;//指向上一个字符所在的节点,方便下一次add
int n ;//字符数组指针
int p ;//节点指针

int newnode ( int l ) {//新建节点
for ( int i = 0 ; i < NN ; ++ i ) next[p][i] = 0 ;
cnt[p] = 0 ;
num[p] = 0 ;
len[p] = l ;
return p ++ ;
}

void init () {//初始化
p = 0 ;
newnode ( 0 ) ;
newnode ( -1 ) ;
last = 0 ;
n = 0 ;
S[n] = -1 ;//开头放一个字符集中没有的字符,减少特判
fail[0] = 1 ;
}

int get_fail ( int x ) {//和KMP一样,失配后找一个尽量最长的
while ( S[n - len[x] - 1] != S[n] ) x = fail[x] ;
return x ;
}

void add ( int c ) {
c -= 'a' ;
S[++ n] = c ;
int cur = get_fail ( last ) ;//通过上一个回文串找这个回文串的匹配位置
if ( !next[cur][c] ) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串
int now = newnode ( len[cur] + 2 ) ;//新建节点
fail[now] = next[get_fail ( fail[cur] )][c] ;//和AC自动机一样建立fail指针,以便失配后跳转
next[cur][c] = now ;
num[now] = num[fail[now]] + 1 ;
}
last = next[cur][c] ;
cnt[last] ++ ;
}

void count () {
for ( int i = p - 1 ; i >= 0 ; -- i ) cnt[fail[i]] += cnt[i] ;
//父亲累加儿子的cnt,因为如果fail[v]=u,则u一定是v的子回文串!
}
}pt;
/************************************************************/

char s[N];

int main(void) {
if(fopen("in", "r")!=NULL) {freopen("in", "r", stdin); freopen("out", "w", stdout);}
scanf("%s", s);
int len = strlen(s);
pt.init();
int ans = 0;
for(int i = 0; i < len; i++) {
pt.add(s[i]);
if(pt.cnt[pt.last] == 1)
ans++;
printf("%d%c", ans, " \n"[i == len - 1]);
}
return 0;
}

Tsinsen A1280 最长双回文串

题意

一个字符串,在中间某个位置切开,能形成两个回文串,则称这个回文串为双回文串。现在需要求最长双回文串。

分析

令a[i]表示以i结束的最长回文串长度,b[i]表示从i开始的最长回文串长度。
每插入一个字符,就更新这两个数组,最后扫一遍,取a[i] + b[i+1]的最大值。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include<bits/stdc++.h>
#define x first
#define y second
#define ok cout << "ok" << endl;
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef vector<int> vi;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const long double PI = acos(-1.0);
const int INF = 0x3f3f3f3f;
const ll LINF = 9223372036854775807;
const double Eps = 1e-7;
const int N = 1e5+9;

int a[N], b[N];

/************************************************************/
/* 回文自动机:解决一类回文字符串问题(Tested 2 times)
* 时间复杂度:O(|S| * log(字符集个数))
*/
const int MAXN = 1e5 + 9;
const int NN = 26;

struct Palindromic_Tree {
int next[MAXN][NN] ;//next指针,next指针和字典树类似,指向的串为当前串两端加上同一个字符构成
int fail[MAXN] ;//fail指针,失配后跳转到fail指针指向的节点
int cnt[MAXN] ; //cnt[i]表示i表示的回文字符串在整个字符串中出现了多少次
int num[MAXN] ; //num[i]表示i表示的回文字符串中有多少个本质不同的字符串(包括本身)
int len[MAXN] ;//len[i]表示节点i表示的回文串的长度
int S[MAXN] ;//存放添加的字符
int last ;//指向上一个字符所在的节点,方便下一次add
int n ;//字符数组指针
int p ;//节点指针

int newnode ( int l ) {//新建节点
for ( int i = 0 ; i < NN ; ++ i ) next[p][i] = 0 ;
cnt[p] = 0 ;
num[p] = 0 ;
len[p] = l ;
return p ++ ;
}

void init () {//初始化
p = 0 ;
newnode ( 0 ) ;
newnode ( -1 ) ;
last = 0 ;
n = 0 ;
S[n] = -1 ;//开头放一个字符集中没有的字符,减少特判
fail[0] = 1 ;
}

int get_fail ( int x ) {//和KMP一样,失配后找一个尽量最长的
while ( S[n - len[x] - 1] != S[n] ) x = fail[x] ;
return x ;
}

// 插入的是字符
void add (int id, int c ) {
c -= 'a' ;
S[++ n] = c ;
int cur = get_fail ( last ) ;//通过上一个回文串找这个回文串的匹配位置
if ( !next[cur][c] ) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串
int now = newnode ( len[cur] + 2 ) ;//新建节点
fail[now] = next[get_fail ( fail[cur] )][c] ;//和AC自动机一样建立fail指针,以便失配后跳转
next[cur][c] = now ;
num[now] = num[fail[now]] + 1 ;
}
last = next[cur][c] ;
a[id] = len[last];
b[id - len[last] + 1] = max(b[id - len[last] + 1], len[last]);
cnt[last] ++ ;
}

void count () {
for ( int i = p - 1 ; i >= 0 ; -- i ) cnt[fail[i]] += cnt[i] ;
//父亲累加儿子的cnt,因为如果fail[v]=u,则u一定是v的子回文串!
}
}pt;
/************************************************************/

char s[N];

int main(void) {
if(fopen("in", "r")!=NULL) {freopen("in", "r", stdin); freopen("out", "w", stdout);}
scanf("%s", s);
int len = strlen(s);
pt.init();
for(int i = 0; i < len; i++) {
pt.add(i, s[i]);
}
int ans = 2;
for(int i = 0; i < len - 1; i++) {
ans = max(ans, a[i] + b[i + 1]);
}
printf("%d\n", ans);

return 0;
}

Tsinsen A1255 拉拉队排练

题意

求前k大奇数长度的回文串的乘积。

分析

先建回文自动机,然后dfs奇根节点,找出奇数长度的回文串的长度和个数,放到容器,排个序贪心取出来即可。注意不能直接递归dfs,会爆栈,需要用stack。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#pragma comment(linker, "/STACK:1024000000,1024000000") 

#include<bits/stdc++.h>
#define x first
#define y second
#define ok cout << "ok" << endl;
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef vector<int> vi;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const long double PI = acos(-1.0);
const int INF = 0x3f3f3f3f;
const ll LINF = 9223372036854775807;
const double Eps = 1e-7;
const int N = 1e6+9;

const int mod = 19930726;


/************************************************************/
/* 回文自动机:解决一类回文字符串问题(Tested 4 times)
* 时间复杂度:O(|S| * log(字符集个数))
*/
const int MAXN = 1e6 + 9;
const int NN = 26; //字符集个数

struct Palindromic_Tree {
int next[MAXN][NN];//next指针,next指针和字典树类似,指向的串为当前串两端加上同一个字符构成
int fail[MAXN];//fail指针,失配后跳转到fail指针指向的节点
int cnt[MAXN]; //cnt[i]表示i表示的回文字符串在整个字符串中出现了多少次
int num[MAXN]; //num[i]表示i表示的回文字符串中有多少个本质不同的字符串(包括本身)
int len[MAXN];//len[i]表示节点i表示的回文串的长度
int S[MAXN];//存放添加的字符
int last;//指向上一个字符所在的节点,方便下一次add
int n;//字符数组指针
int p;//节点指针

int newnode(int l) {//新建节点
for(int i = 0; i < NN; i++) next[p][i] = 0;
cnt[p] = 0;
num[p] = 0;
len[p] = l;
return p++;
}

void init() {//初始化
p = 0;
newnode(0);
newnode(-1);
last = 0;
n = 0;
S[n] = -1;//开头放一个字符集中没有的字符,减少特判
fail[0] = 1;
}

int get_fail(int x) {//和KMP一样,失配后找一个尽量最长的
while(S[n - len[x] - 1] != S[n]) x = fail[x];
return x ;
}

// 插入的是字符
void add(int c) {
c -= 'a';
S[++ n] = c;
int cur = get_fail(last);//通过上一个回文串找这个回文串的匹配位置
if(!next[cur][c]) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串
int now = newnode(len[cur] + 2);//新建节点
fail[now] = next[get_fail(fail[cur])][c];//和AC自动机一样建立fail指针,以便失配后跳转
next[cur][c] = now;
num[now] = num[fail[now]] + 1;
}
last = next[cur][c];
cnt[last]++;
}

void count () {
for(int i = p - 1; i >= 0; i--) cnt[fail[i]] += cnt[i];
//父亲累加儿子的cnt,因为如果fail[v]=u,则u一定是v的子回文串!
}
}pt;
/************************************************************/

char s[N];
int n;
ll k;
pii a[N];
vector<pii> v;
ll sum;
queue<int>que;

void dfs(int u) {
que.push(u);
while(!que.empty()) {
int u = que.front(); que.pop();
for(int i = 0; i < 26; i++) {
int t = pt.next[u][i];
if(t) {
sum += pt.cnt[t];
que.push(t);
//cout << t << endl;
v.push_back(pii(pt.len[t], pt.cnt[t]));
//dfs(t);
}
}
}
}

bool cmp(pii a, pii b) {
return a.x > b.x;
}

ll ksm(ll x, ll y) {
ll ans = 1;
while(y) {
if(y & 1) (ans *= x) %= mod;
(x *= x) %= mod;
y >>= 1;
}
return ans;
}

int main(void) {
if(fopen("in", "r")!=NULL) {freopen("in", "r", stdin); freopen("out", "w", stdout);}
scanf("%d%lld", &n, &k);
scanf("%s", s);
n = strlen(s);
pt.init();
for(int i = 0; i < n; i++) {
pt.add(s[i]);
}
pt.count();
dfs(1);
if(sum < k) {
printf("-1\n");
}
else {
sort(v.begin(), v.end(), cmp);
ll ans = 1;
for(auto i: v) {
ll num = min(1LL * i.y, k);
(ans *= ksm(i.x, num)) %= mod;
k -= i.y;
if(k <= 0) break;
}
printf("%lld\n", ans);
}

return 0;
}

Tsinsen A1393 Palisection

题意

求相交回文串的对数。

分析

直接求相交回文串的对数的话,不好求。经过思考,发现所有回文串对数以及不相交回文串的对数比较好求,而这两者相减就是答案,得解。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include<bits/stdc++.h>
#define x first
#define y second
#define ok cout << "ok" << endl;
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef vector<int> vi;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const long double PI = acos(-1.0);
const int INF = 0x3f3f3f3f;
const ll LINF = 9223372036854775807;
const double Eps = 1e-7;
const int N = 2e6+9;

const int mod = 51123987;

/************************************************************/
/* 回文自动机:解决一类回文字符串问题(Tested 5 times)
* 时间复杂度:O(|S| * log(字符集个数))
*/
const int MAXN = 2e6 + 9;
const int NN = 26; //字符集个数

struct Palindromic_Tree {
int next[MAXN][NN];//next指针,next指针和字典树类似,指向的串为当前串两端加上同一个字符构成
int fail[MAXN];//fail指针,失配后跳转到fail指针指向的节点
int cnt[MAXN]; //cnt[i]表示i表示的回文字符串在整个字符串中出现了多少次
int num[MAXN]; //num[i]表示i表示的回文字符串中有多少个本质不同的字符串(包括本身)
int len[MAXN];//len[i]表示节点i表示的回文串的长度
int S[MAXN];//存放添加的字符
int last;//指向上一个字符所在的节点,方便下一次add
int n;//字符数组指针
int p;//节点指针

int newnode(int l) {//新建节点
for(int i = 0; i < NN; i++) next[p][i] = 0;
cnt[p] = 0;
num[p] = 0;
len[p] = l;
return p++;
}

void init() {//初始化
p = 0;
newnode(0);
newnode(-1);
last = 0;
n = 0;
S[n] = -1;//开头放一个字符集中没有的字符,减少特判
fail[0] = 1;
}

int get_fail(int x) {//和KMP一样,失配后找一个尽量最长的
while(S[n - len[x] - 1] != S[n]) x = fail[x];
return x ;
}

// 插入的是字符
int add(int c) {
c -= 'a';
S[++ n] = c;
int cur = get_fail(last);//通过上一个回文串找这个回文串的匹配位置
if(!next[cur][c]) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串
int now = newnode(len[cur] + 2);//新建节点
fail[now] = next[get_fail(fail[cur])][c];//和AC自动机一样建立fail指针,以便失配后跳转
next[cur][c] = now;
num[now] = num[fail[now]] + 1;
}
last = next[cur][c];
cnt[last]++;
return num[last];
}

void count () {
for(int i = p - 1; i >= 0; i--) cnt[fail[i]] += cnt[i];
//父亲累加儿子的cnt,因为如果fail[v]=u,则u一定是v的子回文串!
}
}pt;
/************************************************************/

char s[N];
int n;
ll sum[N], ans;

int main(void) {
if(fopen("in", "r")!=NULL) {freopen("in", "r", stdin); freopen("out", "w", stdout);}
scanf("%d%s", &n, s);
pt.init();
for(int i = 0; i < n; i++) {
(sum[i+1] = sum[i] + pt.add(s[i])) %= mod;
}
(ans = 1LL * sum[n] * (sum[n] - 1) / 2) %= mod;
reverse(s, s + n);
pt.init();
ll t = 0;
for(int i = 0; i < n; i++) {
t = pt.add(s[i]);
ans -= t * sum[n - i - 1];
(ans += mod) %= mod;
}
printf("%lld\n", ans);

return 0;
}

Gym 100548G The Problem to Slow Down You

题意

给两个字符串,求两个字符串中相同回文串的对数。

分析

首先分别给两个字符串建立回文自动机,然后分别dfs一下奇偶根节点,累加两个回文自动机相同位置的节点的cnt乘积即可。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#include<bits/stdc++.h>
#define x first
#define y second
#define ok cout << "ok" << endl;
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef vector<int> vi;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const long double PI = acos(-1.0);
const int INF = 0x3f3f3f3f;
const ll LINF = 9223372036854775807;
const double Eps = 1e-7;
const int N = 2e5+9;


/************************************************************/
/* 回文自动机:解决一类回文字符串问题(Tested 4 times)
* 时间复杂度:O(|S| * log(字符集个数))
*/
const int MAXN = 2e5 + 9;
const int NN = 26; //字符集个数

struct Palindromic_Tree {
int next[MAXN][NN];//next指针,next指针和字典树类似,指向的串为当前串两端加上同一个字符构成
int fail[MAXN];//fail指针,失配后跳转到fail指针指向的节点
int cnt[MAXN]; //cnt[i]表示i表示的回文字符串在整个字符串中出现了多少次
int num[MAXN]; //num[i]表示i表示的回文字符串中有多少个本质不同的字符串(包括本身)
int len[MAXN];//len[i]表示节点i表示的回文串的长度
int S[MAXN];//存放添加的字符
int last;//指向上一个字符所在的节点,方便下一次add
int n;//字符数组指针
int p;//节点指针

int newnode(int l) {//新建节点
for(int i = 0; i < NN; i++) next[p][i] = 0;
cnt[p] = 0;
num[p] = 0;
len[p] = l;
return p++;
}

void init() {//初始化
p = 0;
newnode(0);
newnode(-1);
last = 0;
n = 0;
S[n] = -1;//开头放一个字符集中没有的字符,减少特判
fail[0] = 1;
}

int get_fail(int x) {//和KMP一样,失配后找一个尽量最长的
while(S[n - len[x] - 1] != S[n]) x = fail[x];
return x ;
}

// 插入的是字符
void add(int c) {
c -= 'a';
S[++ n] = c;
int cur = get_fail(last);//通过上一个回文串找这个回文串的匹配位置
if(!next[cur][c]) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串
int now = newnode(len[cur] + 2);//新建节点
fail[now] = next[get_fail(fail[cur])][c];//和AC自动机一样建立fail指针,以便失配后跳转
next[cur][c] = now;
num[now] = num[fail[now]] + 1;
}
last = next[cur][c];
cnt[last]++;
}

void count () {
for(int i = p - 1; i >= 0; i--) cnt[fail[i]] += cnt[i];
//父亲累加儿子的cnt,因为如果fail[v]=u,则u一定是v的子回文串!
}
}pt[2];
/************************************************************/

char s[2][N];
int len[2], kase;
ll ans;

void dfs(int u0, int u1) {
for(int i = 0; i < 26; i++) {
int t0 = pt[0].next[u0][i];
int t1 = pt[1].next[u1][i];
if(t0 && t1) {
ans += 1LL * pt[0].cnt[t0] * pt[1].cnt[t1];
dfs(t0, t1);
}
}
}

int main(void) {
if(fopen("in", "r")!=NULL) {freopen("in", "r", stdin); freopen("out", "w", stdout);}
int T;
scanf("%d", &T);
while(T--) {
scanf("%s%s", s[0], s[1]);
for(int k = 0; k < 2; k++) {
len[k] = strlen(s[k]);
pt[k].init();
for(int i = 0; i < len[k]; i++) {
pt[k].add(s[k][i]);
}
pt[k].count();
}
ans = 0;
dfs(0, 0);
dfs(1, 1);
printf("Case #%d: %lld\n", ++kase, ans);
}

return 0;
}