CF504E Misha and LCP on Tree 题解
Visits: 320
题意
- 给定一棵 $n$ 个节点的树,每个节点有一个小写字母。
- 有 $m$ 组询问,每组询问为树上 $a \to b$ 和 $c \to d$ 组成的字符串的最长公共前缀。
- $n \le 3 \times 10^5$,$m \le 10^6$。
题解
首先,根据 hash 值具有可减性,可以通过预处理每个点到根以及根到每个点的 hash 值,$\mathcal O(1)$ 求出任意一条没有拐点的路径的 hash 值。
考虑树剖,我们可以将 $a \to b$ 和 $c \to d$ 的路径分别拆成 $\mathcal O(\log n)$ 个区间,然后每次消掉一个区间。
如果消不掉这个区间,则在这个区间上二分找到答案。
总时间复杂度 $\mathcal O(n + m \log n)$。
另外,在 CF 上最好写双 hash,且不能采用自然溢出,否则大概率被卡。
代码
#define Hash pair <modint, modint>
const int N = 3e5 + 7, mod = 1e9 + 7;
const Hash B = mp((modint)131, (modint)13331);
int n, m, d[N], f[N], s[N], son[N], dfn[N], rnk[N], top[N], num;
Hash p[N], h1[N], h2[N];
char c[N];
vi e[N];
inline Hash operator + (Hash a, Hash b) {
return mp(a.fi + b.fi, a.se + b.se);
}
inline Hash operator - (Hash a, Hash b) {
return mp(a.fi - b.fi, a.se - b.se);
}
inline Hash operator * (Hash a, Hash b) {
return mp(a.fi * b.fi, a.se * b.se);
}
inline Hash H(bool o, int x, int k) {
if (o) return h2[x-k+1] - h2[x+1] * p[k];
return h1[x+k-1] - h1[x-1] * p[k];
}
void dfs(int x) {
s[x] = 1;
for (auto y : e[x])
if (y != f[x]) {
f[y] = x, d[y] = d[x] + 1, dfs(y), s[x] += s[y];
if (s[y] > s[son[x]]) son[x] = y;
}
}
void dfs(int x, int p) {
top[x] = p, dfn[x] = ++num, rnk[num] = x;
if (son[x]) dfs(son[x], p);
for (auto y : e[x]) if (y != f[x] && y != son[x]) dfs(y, y);
}
inline int lca(int x, int y) {
while (top[x] != top[y]) {
if (d[top[x]] > d[top[y]]) swap(x, y);
y = f[top[y]];
}
if (d[x] > d[y]) swap(x, y);
return x;
}
inline vector <pi> get(int x, int y) {
int z = lca(x, y);
vector <pi> o, w;
while (d[top[x]] > d[z]) o.pb(mp(x, top[x])), x = f[top[x]];
o.pb(mp(x, z));
while (d[top[y]] > d[z]) w.pb(mp(top[y], y)), y = f[top[y]];
if (y != z) w.pb(mp(son[z], y));
while (w.size()) o.pb(w.back()), w.pop_back();
return o;
}
int main() {
rd(n), rds(c, n), p[0] = mp(1, 1), p[1] = B;
for (int i = 1, x, y; i < n; i++)
rd(x), rd(y), e[x].pb(y), e[y].pb(x), p[i+1] = p[i] * B;
d[1] = 1, dfs(1), dfs(1, 1);
for (int i = 1; i <= n; i++)
h1[i] = h1[i-1] * B + mp(c[rnk[i]], c[rnk[i]]),
h2[n+1-i] = h2[n+2-i] * B + mp(c[rnk[n+1-i]], c[rnk[n+1-i]]);
rd(m);
while (m--) {
int a, b, c, d, ans = 0;
ui s = 0, t = 0;
rd(a), rd(b), rd(c), rd(d);
vector <pi> f = get(a, b), g = get(c, d);
while (s < f.size() && t < g.size()) {
int df1 = dfn[f[s].fi], df2 = dfn[f[s].se];
int dg1 = dfn[g[t].fi], dg2 = dfn[g[t].se];
bool of = df1 > df2, og = dg1 > dg2;
int lf = (of ? df1 - df2 : df2 - df1) + 1;
int lg = (og ? dg1 - dg2 : dg2 - dg1) + 1;
int len = min(lf, lg);
Hash hf = H(of, df1, len);
Hash hg = H(og, dg1, len);
if (hf == hg) {
if (len == lf) ++s;
else f[s].fi = rnk[df1+(of?-1:1)*len];
if (len == lg) ++t;
else g[t].fi = rnk[dg1+(og?-1:1)*len];
ans += len;
} else {
int l = 1, r = len;
while (l < r) {
int mid = (l + r) >> 1;
hf = H(of, df1, mid);
hg = H(og, dg1, mid);
if (hf == hg) l = mid + 1;
else r = mid;
}
ans += l - 1; break;
}
}
print(ans);
}
return 0;
}