给一棵含有 n n n 个结点的有根树,根结点为 1 1 1,编号为 i i i 的点有点权 a i a_i ai?( i ∈ [ 1 , n ] i \in [1,n] i∈[1,n])。现在有两种操作,格式如下:
现有长度为 m m m 的操作序列,请对于每个第二类操作给出正确的结果。
输入的第一行包含两个正整数 n , m n,m n,m,用一个空格分隔。
第二行包含 n n n 个整数 a 1 , a 2 , … , a n a_1,a_2,\ldots,a_n a1?,a2?,…,an?,相邻整数之间使用一个空格分隔。
接下来 n ? 1 n-1 n?1 行,每行包含两个正整数 u i , v i u_i,v_i ui?,vi?,表示结点 u i u_i ui? 和 v i v_i vi? 之间有一条边。
接下来 m m m 行,每行包含一个操作。
输出若干行,每行对应一个查询操作的答案。
4 4
1 2 3 4
1 2
1 3
2 4
2 1
1 1 0
2 1
2 2
4
5
6
对于 30 30 30% 的评测用例, n , m ≤ 1000 n,m \leq 1000 n,m≤1000;
对于所有评测用例, 1 ≤ n , m ≤ 100000 1 \leq n,m \leq 100000 1≤n,m≤100000, 0 ≤ a i , y ≤ 100000 0 \leq a_i,y \leq 100000 0≤ai?,y≤100000, 1 ≤ u i , v i , x ≤ n 1 \leq u_i,v_i,x \leq n 1≤ui?,vi?,x≤n。
考虑第二个操作,查询以节点 x x x 为根的子树内的所有点的点权的异或和。
类似这种子树查询问题,我们通常使用 DFS 序对树进行预处理。具体地说,在 DFS 遍历中,我们从根节点开始,依次遍历它的每个子节点。对于每个子节点,我们首先遍历它的子树,然后回溯到该子节点,继续遍历它的兄弟节点。在遍历的过程中,我们可以记录每个节点在 DFS 序中的遍历顺序,即第一次遍历到该节点的时间戳和最后一次遍历到该节点的时间戳。这里的时间戳可以使用一个计数器来实现,每次遍历到一个新节点时,计数器加 1 1 1,表示当前节点的时间戳。当回溯到该节点时,表示当前节点的最后一次遍历时间戳。
这样操作有什么作用呢?假设我们有一个长度大于 n n n 的数组 a a a,我们记进入每个点 i i i 的时间戳为 in [ i ] \text{in}[i] in[i],回溯到点 i i i 的时间戳为 out [ i ] \text{out}[i] out[i],同时将每个点的点权赋值到 a [ in [ i ] ] a[\text{in}[i]] a[in[i]] 上。这样对于一个根为 x x x 的子树内所有点的点权异或和就等价于 a a a 数组区间 [ in [ x ] , out [ x ] ] [\text{in}[x],\text{out}[x]] [in[x],out[x]] 的异或和。这样我们就将复杂的树上询问,转化为我们熟悉的数组区间查询问题。
接下来考虑操作 1 1 1,将点 x x x 的点权改为 y y y。
结合上述分析,该操作即是将 a [ in [ x ] ] a[\text{in}[x]] a[in[x]] 改为 y y y。
综上所述,我们需要对 a a a 数组进行一个单点修改和区间查询的操作,这个经典操作我们可以使用树状数组或者线段树来维护,代码中使用的是树状数组。具体地说,我们使用一个树状数组 a a a 来维护树的 DFS 序的前缀异或序列, a i a_i ai? 表示区间 [ 1 , i ] [1,i] [1,i] 的异或和。
时间复杂度为 O ( n log ? n ) O(n \log n) O(nlogn)。
#include<bits/stdc++.h>
using namespace std;
const int N = 100010;
template <typename T>
struct Fenwick {
int n;
std::vector<T> a;
Fenwick(int n = 0) {
init(n);
}
void init(int n) {
this->n = n;
a.assign(n + 1, T());
}
void add(int x, T v) {
for (; x <= n; x += x & (-x)) {
a[x] ^= v;
}
}
T sum(int x) {
auto ans = T();
for (; x; x -= x & (-x)) {
ans ^= a[x];
}
return ans;
}
T rangeSum(int l, int r) {
return sum(r) ^ sum(l);
}
};
int n, m, tot;
int a[N], in[N], out[N];
std::vector<int> e[N];
void dfs(int u, int fa) {
in[u] = ++tot;
for (auto v : e[u]) {
if (v == fa) continue;
dfs(v, u);
}
out[u] = tot;
}
int main()
{
ios_base :: sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin >> n >> m;
Fenwick<int> tr(n);
for (int i = 1; i <= n; ++i) cin >> a[i];
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1, 0);
for (int i = 1; i <= n; ++i) tr.add(in[i], a[i]);
int op, x, y;
for (int i = 0; i < m; ++i) {
cin >> op >> x;
if (op == 1) {
cin >> y;
int v = tr.rangeSum(in[x] - 1, in[x]);
tr.add(in[x], y ^ v);
} else {
cout << tr.rangeSum(in[x] - 1, out[x]) << '\n';
}
}
return 0;
}
import java.util.*;
import java.io.*;
public class Main {
static int N = 100010;
static int n, m, tot;
static int[] a = new int[N], in = new int[N], out = new int[N], b = new int[N];
static List<Integer>[] e = new List[N];
static void add(int x, int v) {
for (; x <= n; x += x & (-x)) {
b[x] ^= v;
}
}
static int sum(int x) {
int ans = 0;
if (x == 0) return 0;
for (; x > 0; x -= x & (-x)) {
ans ^= b[x];
}
return ans;
}
static int rangeSum(int l, int r) {
return sum(r) ^ sum(l);
}
static void dfs(int u, int fa) {
in[u] = ++tot;
for (int v : e[u]) {
if (v == fa) continue;
dfs(v, u);
}
out[u] = tot;
}
public static void main(String[] args) throws IOException {
BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(System.out));
String[] temp = reader.readLine().split(" ");
n = Integer.parseInt(temp[0]);
m = Integer.parseInt(temp[1]);
temp = reader.readLine().split(" ");
for (int i = 1; i <= n; ++i) {
a[i] = Integer.parseInt(temp[i - 1]);
e[i]=new ArrayList<>();
}
for (int i = 0; i < n - 1; ++i) {
temp = reader.readLine().split(" ");
int u = Integer.parseInt(temp[0]);
int v = Integer.parseInt(temp[1]);
e[u].add(v);
e[v].add(u);
}
dfs(1, 0);
for (int i = 1; i <= n; ++i) {
add(in[i], a[i]);
}
for (int i = 0; i < m; ++i) {
temp = reader.readLine().split(" ");
int op = Integer.parseInt(temp[0]);
int x = Integer.parseInt(temp[1]);
if (op == 1) {
int y = Integer.parseInt(temp[2]);
int v = rangeSum(in[x] - 1, in[x]);
add(in[x], y ^ v);
} else {
writer.write(rangeSum(in[x] - 1, out[x]) + "\n");
}
}
reader.close();
writer.flush();
writer.close();
}
}
import sys
N = 100010
n, m, tot = 0, 0, 0
a = [0]*N
in_ = [0]*N
out = [0]*N
b = [0]*N
e = [[] for _ in range(N)]
def add(x, v):
while x <= n:
b[x] ^= v
x += x & (-x)
def sum_(x):
ans = 0
if x == 0:
return ans
while x > 0:
ans ^= b[x]
x -= x & (-x)
return ans
def rangeSum(l, r):
return sum_(r) ^ sum_(l)
def dfs(u, fa):
global tot
in_[u] = tot = tot + 1
for v in e[u]:
if v == fa:
continue
dfs(v, u)
out[u] = tot
def main():
global n, m, tot
n, m = map(int, sys.stdin.readline().split())
a[1:n+1] = map(int, sys.stdin.readline().split())
for _ in range(n - 1):
u, v = map(int, sys.stdin.readline().split())
e[u].append(v)
e[v].append(u)
dfs(1, 0)
for i in range(1, n+1):
add(in_[i], a[i])
for _ in range(m):
op, x, *extra = map(int, sys.stdin.readline().split())
if op == 1:
y = extra[0]
v = rangeSum(in_[x] - 1, in_[x])
add(in_[x], y ^ v)
else:
print(rangeSum(in_[x] - 1, out[x]))
if __name__ == "__main__":
main()