动态树分治

动态点分治

动态点分治用来解决 带点权/边权修改 的树上路径信息统计问题。

点分树

回顾点分治的计算过程。

对于一个结点 x 来说,其子树中的简单路径包括两种:经过结点 x 的,由一条或两条从 x 出发的路径组成的;和不经过结点 x 的,即已经包含在其所有儿子结点子树中的路径。

对于一个子树中简单路径的计算,我们选择一个分治中心 rt ,计算经过该节点的子树中路径的信息,然后对于其每个儿子结点,将删去 rt 后该点所在连通块作为一个子树,递归计算。选择的分治中心点可以构成一个树形结构,称为 点分树 。我们发现,计算点分树中同一层的结点所代表的连通块(即以该结点为分治中心的连通块)的大小总和是 O(n) 的。这意味着,点分治的时间复杂度是与点分树的深度相关的,若点分树的深度为 h ,则点分治的复杂度为 O(nh)

可以证明,当我们每次选择连通块的重心作为分治中心的时候,点分树的深度最小,为 O(\log n) 的。这样,我们就可以在 O(n\log n) 的时间复杂度内统计树上 O(n^2) 条路径的信息了。

由于树的形态在动态点分治的过程中不会改变,所以点分树的形态在动态点分治的过程中也不会改变。

下面给出求点分树的参考代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
void calcsiz(int x, int f) {
  siz[x] = 1;
  maxx[x] = 0;
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != f && !vis[p[j]]) {
      calcsiz(p[j], x);
      siz[x] += siz[p[j]];
      maxx[x] = max(maxx[x], siz[p[j]]);
    }
  maxx[x] =
      max(maxx[x], sum - siz[x]);  // maxx[x] 表示以 x 为根时的最大子树大小
  if (maxx[x] < maxx[rt])
    rt = x;  //这里不能写 <= ,保证在第二次 calcsiz 时 rt 不改变
}
void pre(int x) {
  vis[x] = true;  //表示在之后的过程中不考虑 x 这个点
  for (int j = h[x]; j; j = nxt[j])
    if (!vis[p[j]]) {
      sum = siz[p[j]];
      rt = 0;
      maxx[rt] = inf;
      calcsiz(p[j], -1);
      calcsiz(rt, -1);  //计算两次,第二次求出以 rt 为根时的各子树大小
      fa[rt] = x;
      pre(rt);  //记录点分树上的父亲
    }
}
int main() {
  sum = n;
  rt = 0;
  maxx[rt] = inf;
  calcsiz(1, -1);
  calcsiz(rt, -1);
  pre(rt);
}

实现修改

在查询和修改的时候,我们在点分树上暴力跳父亲修改。由于点分树的深度最多是 O(\log n) 的,所以这样做复杂度能得到保证。

在动态点分治的过程中,需要一个结点到其点分树上的祖先的距离等其他信息,由于一个点最多有 O(\log n) 个祖先,我们可以在计算点分树时额外计算深度 dep[x] 或使用 LCA,预处理出这些距离或实现实时查询。 注意 :一个结点到其点分树上的祖先的距离不一定递增,不能累加!

在动态点分治的过程中,一个结点在其点分树上的祖先结点的信息中可能会被重复计算,这是我们需要消去重复部分的影响。一般的方法是对于一个连通块用两种方式记录:一个是其到分治中心的距离信息,另一个是其到点分树上分治中心父亲的距离信息。这一部分内容将在例题中得到展现。

例题luogu P2056[ZJOI2007]捉迷藏

给定一棵有 n 个结点的树,初始时所有结点都是黑色的。你需要实现以下两种操作:

  1. 反转一个结点的颜色(白变黑,黑变白);
  2. 询问树上两个最远的黑点的距离。

    n\le 10^5,m\le 5\times 10^5

求出点分树,对于每个结点 x 维护两个 可删堆dist[x] 存储结点 x 代表的连通块中的所有黑点到 x 的距离信息, ch[x] 表示结点 x 在点分树上的所有儿子和它自己中的黑点到 x 的距离信息,由于本题贪心的求答案方法,且两个来自于同一子树的路径不能成为一条完成的路径,我们只在这个堆中插入其自己的值和其每个子树中的最大值。我们发现, ch[x] 中最大的两个值(如果没有两个就是所有值)的和就是分治时分支中心为 x 时经过结点 x 的最长黑端点路径。我们可以用可删堆 ans 存储所有结点的答案,这个堆中的最大值就是我们所求的答案。

我们可以根据上面的定义维护 dist[x],ch[x],ans 这些可删堆。当 dist[x] 中的值发生变化时,我们也可以在 O(\log n) 的时间复杂度内维护 ch[x],ans

现在我们来看一下,当我们反转一个点的颜色时, dist[x] 值会发生怎样的改变。当结点原来是黑色时,我们要进行的是删除操作;当结点原来是白色时,我们要进行的是插入操作。

假如我们要反转结点 x 的颜色。对于其所有祖先 u ,我们在 dist[u] 中插入或删除 dist(x,u) ,并同时维护 ch[x],ans 的值。特别的,我们要在 ch[x] 中插入或删除值 0

参考代码:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
const int maxn = 100010;
const int inf = 2e9;
int n, a, b, m, x, col[maxn];
// 0 off 1 on
char op;
int cur, h[maxn * 2], nxt[maxn * 2], p[maxn * 2];
void add_edge(int x, int y) {
  cur++;
  nxt[cur] = h[x];
  h[x] = cur;
  p[cur] = y;
}
bool vis[maxn];
int rt, sum, siz[maxn], maxx[maxn], fa[maxn], dep[maxn];
void calcsiz(int x, int f) {
  siz[x] = 1;
  maxx[x] = 0;
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != f && !vis[p[j]]) {
      calcsiz(p[j], x);
      siz[x] += siz[p[j]];
      maxx[x] = max(maxx[x], siz[p[j]]);
    }
  maxx[x] = max(maxx[x], sum - siz[x]);
  if (maxx[x] < maxx[rt]) rt = x;
}
struct heap {
  priority_queue<int> A, B;  // heap=A-B
  void insert(int x) { A.push(x); }
  void erase(int x) { B.push(x); }
  int top() {
    while (!B.empty() && A.top() == B.top()) A.pop(), B.pop();
    return A.top();
  }
  void pop() {
    while (!B.empty() && A.top() == B.top()) A.pop(), B.pop();
    A.pop();
  }
  int top2() {
    int t = top(), ret;
    pop();
    ret = top();
    A.push(t);
    return ret;
  }
  int size() { return A.size() - B.size(); }
} dist[maxn], ch[maxn], ans;
void dfs(int x, int f, int d, heap& y) {
  y.insert(d);
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != f && !vis[p[j]]) dfs(p[j], x, d + 1, y);
}
void pre(int x) {
  vis[x] = true;
  for (int j = h[x]; j; j = nxt[j])
    if (!vis[p[j]]) {
      rt = 0;
      maxx[rt] = inf;
      sum = siz[p[j]];
      calcsiz(p[j], -1);
      calcsiz(rt, -1);
      fa[rt] = x;
      dfs(p[j], -1, 1, dist[rt]);
      ch[x].insert(dist[rt].top());
      dep[rt] = dep[x] + 1;
      pre(rt);
    }
  ch[x].insert(0);
  if (ch[x].size() >= 2)
    ans.insert(ch[x].top() + ch[x].top2());
  else if (ch[x].size())
    ans.insert(ch[x].top());
}
struct LCA {
  int dep[maxn], lg[maxn], fa[maxn][20];
  void dfs(int x, int f) {
    for (int j = h[x]; j; j = nxt[j])
      if (p[j] != f) dep[p[j]] = dep[x] + 1, fa[p[j]][0] = x, dfs(p[j], x);
  }
  void init() {
    dfs(1, -1);
    for (int i = 2; i <= n; i++) lg[i] = lg[i / 2] + 1;
    for (int j = 1; j <= lg[n]; j++)
      for (int i = 1; i <= n; i++) fa[i][j] = fa[fa[i][j - 1]][j - 1];
  }
  int query(int x, int y) {
    if (dep[x] > dep[y]) swap(x, y);
    int k = dep[y] - dep[x];
    for (int i = 0; k; k = k / 2, i++)
      if (k & 1) y = fa[y][i];
    if (x == y) return x;
    k = dep[x];
    for (int i = lg[k]; i >= 0; i--)
      if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
    return fa[x][0];
  }
  int dist(int x, int y) { return dep[x] + dep[y] - 2 * dep[query(x, y)]; }
} lca;
int d[maxn][20];
int main() {
  scanf("%d", &n);
  for (int i = 1; i < n; i++)
    scanf("%d%d", &a, &b), add_edge(a, b), add_edge(b, a);
  lca.init();
  rt = 0;
  maxx[rt] = inf;
  sum = n;
  calcsiz(1, -1);
  calcsiz(rt, -1);
  pre(rt);
  // for(int i=1;i<=n;i++)printf("%d ",fa[i]);printf("\n");
  for (int i = 1; i <= n; i++)
    for (int j = i; j; j = fa[j]) d[i][dep[i] - dep[j]] = lca.dist(i, j);
  scanf("%d", &m);
  while (m--) {
    scanf(" %c", &op);
    if (op == 'G') {
      if (ans.size())
        printf("%d\n", ans.top());
      else
        printf("-1\n");
    } else {
      scanf("%d", &x);
      if (!col[x]) {
        if (ch[x].size() >= 2) ans.erase(ch[x].top() + ch[x].top2());
        ch[x].erase(0);
        if (ch[x].size() >= 2) ans.insert(ch[x].top() + ch[x].top2());
        for (int i = x; fa[i]; i = fa[i]) {
          if (ch[fa[i]].size() >= 2)
            ans.erase(ch[fa[i]].top() + ch[fa[i]].top2());
          ch[fa[i]].erase(dist[i].top());
          dist[i].erase(d[x][dep[x] - dep[fa[i]]]);
          if (dist[i].size()) ch[fa[i]].insert(dist[i].top());
          if (ch[fa[i]].size() >= 2)
            ans.insert(ch[fa[i]].top() + ch[fa[i]].top2());
        }
      } else {
        if (ch[x].size() >= 2) ans.erase(ch[x].top() + ch[x].top2());
        ch[x].insert(0);
        if (ch[x].size() >= 2) ans.insert(ch[x].top() + ch[x].top2());
        for (int i = x; fa[i]; i = fa[i]) {
          if (ch[fa[i]].size() >= 2)
            ans.erase(ch[fa[i]].top() + ch[fa[i]].top2());
          if (dist[i].size()) ch[fa[i]].erase(dist[i].top());
          dist[i].insert(d[x][dep[x] - dep[fa[i]]]);
          ch[fa[i]].insert(dist[i].top());
          if (ch[fa[i]].size() >= 2)
            ans.insert(ch[fa[i]].top() + ch[fa[i]].top2());
        }
      }
      col[x] ^= 1;
    }
  }
  return 0;
}
例题bzoj 3730 震波

给定一棵有 n 个结点的树,树上每个结点都有一个权值 v[x] 。实现以下两种操作:

  1. 询问与结点 x 距离不超过 y 的结点权值和;
  2. 修改结点 x 的点权为 y ,即 v[x]=y

我们用动态开点权值线段树记录距离信息。

类似于上题的思路,对于每个结点,我们维护线段树 dist[x] ,表示分治块 x 中的所有结点到结点 x 的距离信息,下标为距离,权值加上点权。线段树 ch[x] 表示分治块 x 中所有结点到结点 x 在分治树上的父亲结点的距离信息。

在本题中,所有查询和修改都需要在点分树上对所有祖先进行修改。

以查询操作为例,如果我们要查询距离结点 x 不超过 y 的结点的权值和,我们要先将答案加上线段树 dist[x] 中下标从 0y 的权值和,然后我们遍历 x 的所有祖先 u ,设其低一级祖先为 v ,令 d=dist(x,u) ,如果我们不进入包含 x 的子树,即以 v 为根的子树,那么我们要将答案加上线段树 dist[u] 中下标从 0y-d 的权值和。由于我们重复计算了以 v 为根的部分,我们要将答案减去线段树 ch[v] 中下标从 0y-d 的权值和。

在进行修改操作时,我们要同时维护 dist[x]ch[x]

参考代码:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
const int maxn = 100010;
const int inf = 2e9;
const int ddd = 6000010;
struct Segtree {
  int cnt, rt[maxn], sum[ddd], lc[ddd], rc[ddd];
  void update(int& o, int l, int r, int x, int v) {
    if (!o) o = ++cnt;
    if (l == r) {
      sum[o] += v;
      return;
    }
    int mid = (l + r) >> 1;
    if (x <= mid)
      update(lc[o], l, mid, x, v);
    else
      update(rc[o], mid + 1, r, x, v);
    sum[o] = sum[lc[o]] + sum[rc[o]];
  }
  int query(int o, int l, int r, int ql, int qr) {
    if (!o || r < ql || l > qr) return 0;
    if (ql <= l && r <= qr) return sum[o];
    int mid = (l + r) >> 1;
    return query(lc[o], l, mid, ql, qr) + query(rc[o], mid + 1, r, ql, qr);
  }
} dist, ch;
int n, m, val[maxn], u, v, op, x, y, lstans;
int cur, h[maxn * 2], nxt[maxn * 2], p[maxn * 2];
void add_edge(int x, int y) {
  cur++;
  nxt[cur] = h[x];
  h[x] = cur;
  p[cur] = y;
}
struct LCA {
  int dep[maxn], lg[maxn], fa[maxn][20];
  void dfs(int x, int f) {
    for (int j = h[x]; j; j = nxt[j])
      if (p[j] != f) dep[p[j]] = dep[x] + 1, fa[p[j]][0] = x, dfs(p[j], x);
  }
  void init() {
    dep[1] = 1;
    dfs(1, -1);
    for (int i = 2; i <= n; i++) lg[i] = lg[i / 2] + 1;
    for (int j = 1; j <= lg[n]; j++)
      for (int i = 1; i <= n; i++) fa[i][j] = fa[fa[i][j - 1]][j - 1];
  }
  int query(int x, int y) {
    if (dep[x] > dep[y]) swap(x, y);
    int k = dep[y] - dep[x];
    for (int i = 0; k; k = k / 2, i++)
      if (k & 1) y = fa[y][i];
    if (x == y) return x;
    k = dep[x];
    for (int i = lg[k]; i >= 0; i--)
      if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
    return fa[x][0];
  }
  int dist(int x, int y) { return dep[x] + dep[y] - 2 * dep[query(x, y)]; }
} lca;
int rt, sum, siz[maxn], maxx[maxn], fa[maxn];
int d[maxn][20], dep[maxn];
bool vis[maxn];
void calcsiz(int x, int fa) {
  siz[x] = 1;
  maxx[x] = 0;
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      calcsiz(p[j], x);
      siz[x] += siz[p[j]];
      maxx[x] = max(maxx[x], siz[p[j]]);
    }
  maxx[x] = max(maxx[x], sum - siz[x]);
  if (maxx[x] < maxx[rt]) rt = x;
}
void dfs1(int x, int fa, int y, int d) {
  ch.update(ch.rt[y], 0, n, d, val[x]);
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) dfs1(p[j], x, y, d + 1);
}
void dfs2(int x, int fa, int y, int d) {
  dist.update(dist.rt[y], 0, n, d, val[x]);
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) dfs2(p[j], x, y, d + 1);
}
void pre(int x) {
  vis[x] = true;
  dfs2(x, -1, x, 0);
  for (int j = h[x]; j; j = nxt[j])
    if (!vis[p[j]]) {
      rt = 0;
      maxx[rt] = inf;
      sum = siz[p[j]];
      calcsiz(p[j], -1);
      calcsiz(rt, -1);
      dfs1(p[j], -1, rt, 1);
      fa[rt] = x;
      dep[rt] = dep[x] + 1;
      pre(rt);
    }
}
int main() {
  scanf("%d%d", &n, &m);
  for (int i = 1; i <= n; i++) scanf("%d", val + i);
  for (int i = 1; i < n; i++)
    scanf("%d%d", &u, &v), add_edge(u, v), add_edge(v, u);
  lca.init();
  rt = 0;
  maxx[rt] = inf;
  sum = n;
  calcsiz(1, -1);
  calcsiz(rt, -1);
  pre(rt);
  // for(int i=1;i<=n;i++)printf("%d ",fa[i]);printf("\n");
  for (int i = 1; i <= n; i++)
    for (int j = i; j; j = fa[j]) d[i][dep[i] - dep[j]] = lca.dist(i, j);
  while (m--) {
    scanf("%d%d%d", &op, &x, &y);
    x ^= lstans;
    y ^= lstans;
    if (op == 0) {
      lstans = dist.query(dist.rt[x], 0, n, 0, y);
      int nww = 0;
      for (int i = x; fa[i]; i = fa[i]) {
        nww = d[x][dep[x] - dep[fa[i]]];  // lca.dist(x,fa[i]);
        lstans += dist.query(dist.rt[fa[i]], 0, n, 0, y - nww);
        lstans -= ch.query(ch.rt[i], 0, n, 0, y - nww);
      }
      printf("%d\n", lstans);
    }
    if (op == 1) {
      int nww = 0;
      dist.update(dist.rt[x], 0, n, 0, y - val[x]);
      for (int i = x; fa[i]; i = fa[i]) {
        nww = d[x][dep[x] - dep[fa[i]]];  // lca.dist(x,fa[i]);
        dist.update(dist.rt[fa[i]], 0, n, nww, y - val[x]);
        ch.update(ch.rt[i], 0, n, nww, y - val[x]);
      }
      val[x] = y;
    }
  }
  return 0;
}

评论