Skip to main content

Tree Algorithms

Why Trees Matter

Trees appear everywhere in competitive programming (CP): organizational hierarchies, file systems, game trees, and as subproblems in graph algorithms. A tree with n nodes is a connected graph with exactly n-1 edges, and it has a unique path between any two nodes.
Pattern Recognition Signals:
  • “n nodes, n-1 edges” → It’s a tree
  • “Unique path between nodes” → Tree path queries
  • “Subtree sum/count” → DFS + subtree aggregation
  • “Path from u to v” → LCA + path decomposition
  • “Root the tree at node 1” → DFS from root

Tree Representations

// Most common: Adjacency list (unrooted tree)
// adj has n + 1 slots so nodes can be indexed directly by their label.
// NOTE(review): n is never assigned in this snippet. It must already hold the
// node count before adj is constructed, otherwise adj[u] below indexes past
// the end; presumably the read of n was elided. Confirm against the caller.
int n;
vector<vector<int>> adj(n + 1);

// Reads n - 1 edges and stores each one in both directions.
for (int i = 0; i < n - 1; i++) {
    int u, v;
    cin >> u >> v;
    adj[u].push_back(v);
    adj[v].push_back(u);
}

// Rooted tree: parent[u] is u's parent, children[u] lists u's direct children.
vector<int> parent(n + 1);
vector<vector<int>> children(n + 1);

// Orients the undirected tree away from u: records p as u's parent, then
// visits every neighbour other than p, registering it as a child of u and
// recursing into it.
void root(int u, int p) {
    parent[u] = p;
    for (const int next : adj[u]) {
        if (next == p) {
            continue;  // the edge back to the parent is not a child edge
        }
        children[u].push_back(next);
        root(next, u);
    }
}

DFS on Trees

Basic Template

vector<int> subtreeSize(n + 1);  // node count of the subtree below each node (itself included)
vector<int> depth(n + 1);        // depth of each node; the caller sets the start node's depth

// Fills subtreeSize[] and depth[] for everything below u.
// p is the node the walk arrived from and is never re-entered.
void dfs(int u, int p) {
    subtreeSize[u] = 1;  // u counts itself

    for (const int child : adj[u]) {
        if (child == p) {
            continue;  // do not walk back up the tree
        }
        depth[child] = depth[u] + 1;
        dfs(child, u);
        subtreeSize[u] += subtreeSize[child];
    }
}

// Start DFS from root
// Node 1 is used as the root with depth 0. The second argument 0 acts as a
// "no parent" placeholder (assumes 0 is never a real neighbour of node 1,
// i.e. node labels start at 1 — TODO confirm).
depth[1] = 0;
dfs(1, 0);

Euler Tour (Flatten Tree to Array)

Converts tree to array for range queries on subtrees.
vector<int> tin(n + 1), tout(n + 1);  // entry index / one-past-the-end index per node
vector<int> euler;                    // nodes in the order they are first visited
int timer = 0;                        // next free position in euler[]

// Pre-order walk that gives every node a position in euler[].
// After the call, the subtree of u occupies the half-open range
// [tin[u], tout[u]) of euler[].
void eulerTour(int u, int p) {
    tin[u] = timer;
    ++timer;
    euler.push_back(u);

    for (const int child : adj[u]) {
        if (child == p) {
            continue;  // skip the edge back to the parent
        }
        eulerTour(child, u);
    }

    // timer has advanced once for u and once for every node below it
    tout[u] = timer;
}

// Subtree of u corresponds to euler[tin[u]...tout[u]-1]
Application: Range queries on subtrees using segment trees.

Pattern 1: Tree Diameter

Problem: Find the longest path in a tree.

Two BFS/DFS Method

// Breadth-first search over adj starting at `start`.
// Returns {farthest node, distance to it in edges}. Among several nodes at
// the maximum distance, the one discovered first is reported; if nothing is
// reachable besides `start`, the result is {start, 0}.
pair<int, int> bfs(int start) {
    vector<int> dist(n + 1, -1);  // -1 marks "not discovered yet"
    dist[start] = 0;

    queue<int> pending;
    pending.push(start);

    pair<int, int> best{start, 0};  // {farthest node so far, its distance}

    while (!pending.empty()) {
        const int cur = pending.front();
        pending.pop();

        for (const int next : adj[cur]) {
            if (dist[next] != -1) {
                continue;  // already discovered
            }
            dist[next] = dist[cur] + 1;
            pending.push(next);

            // strict '>' keeps the first node found at the maximum distance
            if (dist[next] > best.second) {
                best = {next, dist[next]};
            }
        }
    }

    return best;
}

// Length (in edges) of the longest path in the tree, using two BFS passes:
// the node farthest from node 1 is one end of a longest path, and the
// farthest distance measured from that end is the diameter.
int treeDiameter() {
    const int endpoint = bfs(1).first;  // one end of a longest path
    return bfs(endpoint).second;        // distance to the opposite end
}

DP Method (Single DFS; Computes Length Only)

int diameter = 0;      // longest path length (in edges) found so far
int diameterEndpoint;  // not assigned anywhere in this snippet

// Returns the height of u's subtree, i.e. the number of edges on the longest
// downward path starting at u. As a side effect, raises `diameter` to the
// longest path whose topmost node is u: its two deepest child branches joined
// at u.
int dfs(int u, int p) {
    int tallest = 0;  // deepest branch below u
    int runnerUp = 0; // second-deepest branch below u

    for (const int child : adj[u]) {
        if (child == p) {
            continue;
        }
        const int branch = 1 + dfs(child, u);

        if (branch > tallest) {
            runnerUp = tallest;
            tallest = branch;
        } else if (branch > runnerUp) {
            runnerUp = branch;
        }
    }

    const int throughU = tallest + runnerUp;
    if (diameter < throughU) {
        diameter = throughU;
    }

    return tallest;
}

Pattern 2: Lowest Common Ancestor (LCA)

Problem: Find the lowest common ancestor of two nodes.

Binary Lifting

Precompute ancestors at powers of 2 for O(log n) queries.
const int LOG = 20;  // number of lifting levels: jumps of 2^0 .. 2^(LOG-1) edges
vector<vector<int>> up(n + 1, vector<int>(LOG));  // up[v][j]: ancestor of v that is 2^j edges up
vector<int> depth(n + 1);                         // edge distance from the root

// Builds depth[] and the ancestor table up[][] for the tree rooted at `root`.
// "Not visited yet" is detected as depth == 0 (root excluded), so depth[] must
// be all zeros when this is called.
void preprocess(int root) {
    depth[root] = 0;
    up[root][0] = root;  // the root is its own parent, so over-long jumps stay at the root

    // Level-order pass: each newly reached node gets its depth and direct parent.
    queue<int> pending;
    pending.push(root);

    while (!pending.empty()) {
        const int cur = pending.front();
        pending.pop();

        for (const int next : adj[cur]) {
            const bool seen = (next == root) || (depth[next] != 0);
            if (seen) {
                continue;
            }
            depth[next] = depth[cur] + 1;
            up[next][0] = cur;
            pending.push(next);
        }
    }

    // Doubling step: a 2^j jump is two consecutive 2^(j-1) jumps. The level
    // loop is outermost because level j reads level j-1 of other nodes.
    for (int level = 1; level < LOG; ++level) {
        for (int node = 1; node <= n; ++node) {
            const int halfway = up[node][level - 1];
            up[node][level] = up[halfway][level - 1];
        }
    }
}

// Lowest common ancestor of u and v, read from depth[] and up[][];
// both tables must have been filled by preprocess() beforehand.
int lca(int u, int v) {
    // Arrange for u to be at least as deep as v.
    if (depth[v] > depth[u]) {
        swap(u, v);
    }

    // Raise u by the depth gap: one power-of-two jump per set bit of the gap.
    const int gap = depth[u] - depth[v];
    for (int bit = 0; bit < LOG; ++bit) {
        if (gap & (1 << bit)) {
            u = up[u][bit];
        }
    }

    if (u == v) {
        return v;  // v was an ancestor of (or equal to) the original deeper node
    }

    // Largest jumps first: take a jump only when both nodes land on different
    // ancestors, which leaves u and v just below their common ancestor.
    for (int bit = LOG; bit-- > 0;) {
        const int aboveU = up[u][bit];
        const int aboveV = up[v][bit];
        if (aboveU != aboveV) {
            u = aboveU;
            v = aboveV;
        }
    }

    return up[u][0];
}

// Number of edges on the path between u and v: both depths are measured from
// the root, so the shared part above their lowest common ancestor is removed twice.
int dist(int u, int v) {
    const int meet = lca(u, v);
    return depth[u] + depth[v] - 2 * depth[meet];
}
CSES Problems:

| Problem | Rating | Link |
| --- | --- | --- |
| Company Queries I | 1300 | CSES |
| Company Queries II | 1400 | CSES |

Pattern 3: Tree DP

Subtree DP

Problem: Compute something for each subtree.
// Example: Count nodes in each subtree with value > threshold
vector<int> subtreeCount(n + 1);  // per node: matching nodes inside its subtree

// After the call, subtreeCount[x] for every x at or below u holds how many
// nodes in x's subtree have val[] strictly greater than threshold.
void dfs(int u, int p) {
    // u contributes 1 when its own value qualifies, otherwise 0
    if (val[u] > threshold) {
        subtreeCount[u] = 1;
    } else {
        subtreeCount[u] = 0;
    }

    for (const int child : adj[u]) {
        if (child == p) {
            continue;
        }
        dfs(child, u);
        subtreeCount[u] += subtreeCount[child];
    }
}

Rerooting DP

Problem: Compute answer for each node as if it were the root.
// Example: Sum of distances from each node to all other nodes
vector<long long> down(n + 1);  // down[u]: sum of distances from u to the nodes inside u's subtree
vector<long long> up(n + 1);    // up[u]: sum of distances from u to the nodes outside u's subtree
vector<int> sz(n + 1);          // sz[u]: node count of u's subtree

// First pass (bottom-up): fills sz[] and down[] for u and everything below it.
void dfs1(int u, int p) {
    sz[u] = 1;
    down[u] = 0;

    for (const int child : adj[u]) {
        if (child == p) {
            continue;
        }
        dfs1(child, u);
        sz[u] += sz[child];
        // every node in child's subtree is one edge farther from u than from child
        down[u] += down[child] + sz[child];
    }
}

// Second pass (top-down): fills up[] for every node below u from up[u].
// up[] of the node the pass starts at is not written here.
void dfs2(int u, int p) {
    for (const int child : adj[u]) {
        if (child == p) {
            continue;
        }

        // Sum of distances from u to every node outside child's subtree:
        // u's outside part, plus u's inside part without child's branch.
        const long long fromParent = up[u] + down[u] - down[child] - sz[child];

        // Each of those n - sz[child] nodes is one edge farther from child.
        up[child] = fromParent + (n - sz[child]);
        dfs2(child, u);
    }
}

// Answer for node u = down[u] + up[u]

Pattern 4: Centroid Decomposition

Problem: Efficiently answer path queries on trees. The centroid is a node whose removal leaves no subtree larger than n/2.
vector<int> subtree(n + 1);          // sizes written by the most recent getSubtree() walk
vector<bool> removed(n + 1, false);  // set once a node has been chosen as a centroid

// Stores in subtree[] the size of every subtree at or below u, skipping
// nodes already marked removed, and returns the size stored for u.
int getSubtree(int u, int p) {
    subtree[u] = 1;
    for (const int next : adj[u]) {
        if (next == p || removed[next]) {
            continue;
        }
        const int below = getSubtree(next, u);
        subtree[u] += below;
    }
    return subtree[u];
}

// Walks down from u, always stepping into the first non-removed child whose
// subtree[] value exceeds treeSize / 2, and returns the node where no such
// child exists. subtree[] must come from a getSubtree() call on the same start node.
int getCentroid(int u, int p, int treeSize) {
    const int half = treeSize / 2;

    bool descended = true;
    while (descended) {
        descended = false;
        for (const int next : adj[u]) {
            const bool heavy = next != p && !removed[next] && subtree[next] > half;
            if (heavy) {
                p = u;
                u = next;
                descended = true;
                break;  // restart the scan from the new node
            }
        }
    }

    return u;
}

// Centroid decomposition of the not-yet-removed component containing u:
// finds that component's centroid, marks it removed, then recurses into each
// remaining piece around it. -1 is passed as the "no parent" value.
void decompose(int u) {
    const int componentSize = getSubtree(u, -1);
    const int centroid = getCentroid(u, -1, componentSize);

    removed[centroid] = true;

    // Process centroid - all paths through centroid
    // ...

    for (const int next : adj[centroid]) {
        if (removed[next]) {
            continue;
        }
        decompose(next);
    }
}

Pattern 5: Small-to-Large Merging

Problem: Aggregate information from children efficiently.
// Example: Count distinct values in each subtree
vector<set<int>> nodeSet(n + 1);  // working set of values per node; emptied once merged upward
vector<int> answer(n + 1);        // answer[u]: number of distinct values in u's subtree

// Computes answer[] for u and everything below it. Each child's set is folded
// into u's set, always moving the elements of the smaller set into the larger.
void dfs(int u, int p) {
    set<int>& mine = nodeSet[u];
    mine.insert(val[u]);

    for (const int child : adj[u]) {
        if (child == p) {
            continue;
        }
        dfs(child, u);

        set<int>& theirs = nodeSet[child];
        // Keep the larger set on u's side so only the smaller one is re-inserted.
        if (mine.size() < theirs.size()) {
            swap(mine, theirs);
        }
        mine.insert(theirs.begin(), theirs.end());
        theirs.clear();  // the child's answer is already recorded
    }

    answer[u] = static_cast<int>(mine.size());
}
Complexity: O(n log² n) total: with small-to-large merging each element is moved O(log n) times, and each insertion into a std::set costs O(log n).

Pattern 6: Tree Isomorphism

Problem: Check if two trees have the same structure.
map<vector<int>, int> hashMap;  // sorted list of child ids -> id assigned to that shape
int nextHash = 0;               // next unused id

// Returns an id for the shape of the subtree rooted at u (entered from p).
// Two subtrees receive the same id exactly when their sorted child-id lists
// are equal; ids are handed out in order of first appearance.
int getHash(int u, int p) {
    vector<int> signature;
    for (const int child : adj[u]) {
        if (child == p) {
            continue;
        }
        signature.push_back(getHash(child, u));
    }

    // Sorting makes the id independent of the order in which children are listed.
    sort(signature.begin(), signature.end());

    // Inserts {signature, nextHash} only when the signature is new.
    const auto [slot, isNew] = hashMap.try_emplace(signature, nextHash);
    if (isNew) {
        ++nextHash;
    }
    return slot->second;
}

// NOTE(review): pseudocode stub — the parameters have no types and nothing is
// returned, so this does not compile as written.
bool isIsomorphic(tree1, tree2) {
    // Trees are isomorphic if they have same hash when rooted at centroid
    // NOTE(review): a tree can have two centroids, so presumably both rootings
    // of one tree have to be compared against the other — confirm.
}

Common Mistakes

Mistake 1: Forgetting parent check in DFS
// WRONG - infinite recursion (ends in a stack overflow)
for (int v : adj[u]) {
    dfs(v, u);  // Will revisit parent, which then revisits u, and so on
}

// CORRECT - skip the edge that leads back to the parent p
for (int v : adj[u]) {
    if (v != p) dfs(v, u);
}
Mistake 2: Wrong Euler tour indices. The subtree of u is the half-open range [tin[u], tout[u]), not the closed range [tin[u], tout[u]].
Mistake 3: LCA on an unprocessed tree. Always call preprocess(root) before making LCA queries.

Practice Problems

Beginner (1000-1300)

| Problem | Pattern | Link |
| --- | --- | --- |
| Tree Diameter | Two BFS | CSES |
| Subordinates | Subtree size | CSES |
| Tree Distances I | Tree diameter | CSES |

Intermediate (1300-1600)

| Problem | Pattern | Link |
| --- | --- | --- |
| Distance Queries | LCA | CSES |
| Tree Distances II | Rerooting | CSES |
| Counting Paths | Euler tour + difference | CSES |

Advanced (1600-1900)

| Problem | Pattern | Link |
| --- | --- | --- |
| Distinct Colors | Small-to-large | CSES |
| Path Queries II | LCA + segment tree | CSES |

Key Takeaways

Euler Tour

Flattens tree to array for range queries on subtrees.

Binary Lifting

O(log n) LCA queries after O(n log n) preprocessing.

Rerooting DP

Compute answer for all roots in O(n) total.

Small-to-Large

Merge smaller sets into larger for O(n log² n).

Next Up

Chapter 16: Disjoint Set Union

Master Union-Find for dynamic connectivity and Kruskal’s MST.