Segment Tree

Table of Contents

Overview

A segment tree is a data structure that efficiently (i.e., in \(O(\log n)\) time per operation) answers queries of the form: find the sum / minimum of the array elements in a given interval a[l..r], where l and r are inputs to the algorithm. It additionally supports modifying the array: either changing the value of a single element, or changing all elements of a whole subsegment (i.e., assigning some value to every element of a[l..r], or adding some number to every element of a[l..r]).

  • An important feature of segment trees is that they use only a linear amount of memory.
  • Leaf Nodes are the elements of the input array.
  • Each internal node represents some merging of the leaf nodes. The merging may be different for different problems.
  • For each node at index i, the left child is at index 2*i+1, right child at 2*i+2 and the parent is at (i-1)/2.

An example for range sums:

segment-tree-sum.png

Segment Tree (Sum of given range) [1]

// C++ program (it allocates with new[]) to show segment tree operations
// like construction, query and update
#include <stdio.h>
#include <math.h>

// Returns the middle index of the closed index range [s, e]. The value is
// formed as s plus half of the span (e - s), so the sum s + e is never
// computed.
int getMid(int s, int e)
{
    const int half_span = (e - s) / 2;
    return s + half_span;
}

/*  Recursively computes the sum of the array values whose indexes lie in
    the query range [qs, qe].

    st       --> Pointer to segment tree
    ss & se  --> Starting and ending indexes of the segment represented
                 by the current node, i.e., st[si]
    qs & qe  --> Starting and ending indexes of the query range
    si       --> Index of the current node in the segment tree. Callers
                 start with 0, the index of the root

    Returns st[si] when the node's segment lies completely inside the
    query range, 0 when the two ranges do not intersect, and otherwise
    the sum of the results for the two children 2*si+1 and 2*si+2. */
int getSumUtil(int *st, int ss, int se, int qs, int qe, int si)
{
    // The node's segment is entirely covered by the query: its stored
    // sum can be used as is.
    const bool fully_covered = (ss >= qs) && (se <= qe);
    if (fully_covered)
        return st[si];

    // The node's segment and the query have no index in common.
    const bool disjoint = (qs > se) || (qe < ss);
    if (disjoint)
        return 0;

    // Partial overlap: split at the midpoint and combine both halves.
    const int mid = getMid(ss, se);
    const int left_sum = getSumUtil(st, ss, mid, qs, qe, 2*si + 1);
    const int right_sum = getSumUtil(st, mid + 1, se, qs, qe, 2*si + 2);
    return left_sum + right_sum;
}

/* Recursively adds diff to every node whose segment contains index i.

    st   --> Pointer to segment tree
    ss & se --> Starting and ending indexes of the segment represented
                by the current node, i.e., st[si]
    i    --> Index (in the input array) of the element that changed
    diff --> Amount added to each node whose segment contains i
    si   --> Index of the current node in the segment tree */
void updateValueUtil(int *st, int ss, int se, int i, int diff, int si)
{
    // Nothing to do for a node whose segment does not contain i.
    const bool covers_index = (ss <= i) && (i <= se);
    if (!covers_index)
        return;

    st[si] += diff;

    // A single-element segment has no children to visit.
    if (ss == se)
        return;

    // Visit both halves; the half that does not contain i returns at once.
    const int mid = getMid(ss, se);
    updateValueUtil(st, ss, mid, i, diff, 2*si + 1);
    updateValueUtil(st, mid + 1, se, i, diff, 2*si + 2);
}

// Sets arr[i] to new_val and brings the segment tree st (built over
// arr[0..n-1]) in line with the change by handing the difference between
// the new and the old value to updateValueUtil().
// Prints "Invalid Input" and changes nothing when i is outside [0, n-1].
void updateValue(int arr[], int *st, int n, int i, int new_val)
{
    const bool index_ok = (i >= 0) && (i <= n - 1);
    if (!index_ok)
    {
        printf("Invalid Input");
        return;
    }

    // Amount by which every sum that includes arr[i] has to change
    const int diff = new_val - arr[i];

    arr[i] = new_val;
    updateValueUtil(st, 0, n - 1, i, diff, 0);
}

// Returns the sum of the elements with indexes qs (query start) through
// qe (query end) by delegating to getSumUtil() from the root node.
// When qs < 0, qe > n-1 or qs > qe it prints "Invalid Input" and
// returns -1 instead.
int getSum(int *st, int n, int qs, int qe)
{
    const bool valid_range = (qs >= 0) && (qe <= n - 1) && (qs <= qe);
    if (valid_range)
        return getSumUtil(st, 0, n - 1, qs, qe, 0);

    printf("Invalid Input");
    return -1;
}

// Recursively fills the segment tree st for arr[ss..se] and returns the
// sum stored for that segment. si is the index of the current node in st;
// its children are stored at si*2+1 and si*2+2.
int constructSTUtil(int arr[], int ss, int se, int *st, int si)
{
    if (ss != se)
    {
        // More than one element: build both halves, then store their
        // combined sum in this node.
        const int mid = getMid(ss, se);
        const int left_total = constructSTUtil(arr, ss, mid, st, si*2 + 1);
        const int right_total = constructSTUtil(arr, mid + 1, se, st, si*2 + 2);
        st[si] = left_total + right_total;
    }
    else
    {
        // Exactly one element: the node holds that element itself.
        st[si] = arr[ss];
    }

    return st[si];
}

/* Builds the sum segment tree for arr[0..n-1] and returns it.

   arr --> Input array
   n   --> Number of elements of arr covered by the tree

   The function allocates the storage for the tree and calls
   constructSTUtil() to fill it. The storage holds 2*leaves - 1 ints,
   where leaves is the smallest power of two that is >= n. That is the
   same size the expression 2*pow(2, ceil(log2(n))) - 1 describes, but it
   is computed with integers only: a floating-point pow() result that
   comes out slightly below the exact power of two would be truncated by
   the conversion to int and the tree would be under-allocated.

   Slots that constructSTUtil() does not write are zero.

   Returns NULL when arr is NULL or n <= 0 (there is nothing to build).
   Otherwise the returned array was allocated with new[] and the caller
   has to release it with delete[]. */
int *constructST(int arr[], int n)
{
    // An empty input has no root segment; recursing on [0, -1] would
    // never reach the single-element base case.
    if (arr == NULL || n <= 0)
        return NULL;

    // Smallest power of two that is not below n
    size_t leaves = 1;
    while (leaves < static_cast<size_t>(n))
        leaves *= 2;

    // Maximum size of segment tree
    const size_t max_size = 2*leaves - 1;

    // Allocate zero-initialised memory
    int *st = new int[max_size]();

    // Fill the allocated memory st
    constructSTUtil(arr, 0, n-1, st, 0);

    // Return the constructed segment tree
    return st;
}

// Driver program to test above functions: builds the tree for a fixed
// array, prints one range sum, changes one element and prints the same
// range sum again.
int main()
{
    int arr[] = {1, 3, 5, 7, 9, 11};
    int n = sizeof(arr)/sizeof(arr[0]);

    // Build segment tree from given array
    int *st = constructST(arr, n);

    // Print sum of values in array from index 1 to 3
    printf("Sum of values in given range = %d\n", 
            getSum(st, n, 1, 3));

    // Update: set arr[1] = 10 and update corresponding 
    // segment tree nodes
    updateValue(arr, st, n, 1, 10);

    // Find sum after the value is updated
    printf("Updated sum of values in given range = %d\n",
             getSum(st, n, 1, 3));

    // The tree was allocated with new[] inside constructST(); release it
    // here so the program does not leak it.
    delete[] st;
    return 0;
}

Lazy propagation

在对一段区间的数据进行更新时, 利用 lazy propagation 可以做到 \(O(\log n)\)。

节点数据向上更新

将子节点的值更新到父节点。

/* Range-sum variant: recomputes node rt as the sum of its two children,
   which are stored at rt << 1 and rt << 1 | 1. */
void push_up(int rt) {
    const int left = rt << 1;
    const int right = left | 1;
    tree[rt] = tree[left] + tree[right];
}

/* Range-maximum variant: recomputes node rt as the larger of its two
   children, which are stored at rt << 1 and rt << 1 | 1. */
void push_up(int rt) {
    const int left = rt << 1;
    const int right = left | 1;
    tree[rt] = max(tree[left], tree[right]);
}

节点懒惰标记下推

对于区间求和, 子节点的值需要加上lazy标记乘以该子树所统计的区间长度。 len 为父节点统计的区间长度, 则 len - (len >> 1) 为左子树区间长度, len >> 1 为右子树区间长度。

/* Range-sum variant: hands the pending addition stored in lazy[rt] down
   to both children and clears it on rt.

   rt  --> index of the node whose lazy mark is pushed down
   len --> number of elements covered by rt; the left child is credited
           with len - (len >> 1) of them and the right child with
           len >> 1 */
void push_down(int rt, int len) {
    const int left = rt << 1;
    const int right = left | 1;
    const int right_len = len >> 1;
    const int left_len = len - right_len;

    // Each child's sum grows by the pending value times its own length,
    // and the child remembers the value for its own descendants.
    tree[left] += lazy[rt] * left_len;
    lazy[left] += lazy[rt];
    tree[right] += lazy[rt] * right_len;
    lazy[right] += lazy[rt];

    lazy[rt] = 0;
}

对于区间求最大值, 子树的值不需要乘以长度, 所以不需要传递参数len。

/* Range-maximum variant: hands the pending addition stored in lazy[rt]
   down to both children and clears it on rt. The pending value is added
   to each child as is, so no segment length is needed. */
void push_down(int rt) {
    const int left = rt << 1;
    const int right = left | 1;

    tree[left] += lazy[rt];
    lazy[left] += lazy[rt];
    tree[right] += lazy[rt];
    lazy[right] += lazy[rt];

    lazy[rt] = 0;
}

建树

新建一棵长度N的线段树。

#define lchild rt << 1, l, m
#define rchild rt << 1 | 1, m + 1, r
/* Builds the subtree of node rt, which covers positions l..r; the
   defaults describe the root covering 1..N. A node covering a single
   position reads its value from std::cin; any other node builds its left
   half, then its right half, and is then recomputed with push_up().
   The local must be named m because the lchild/rchild macros refer to it. */
void build(int rt = 1, int l = 1, int r = N) {
    if (l != r) {
        int m = (l + r) >> 1;
        build(lchild);
        build(rchild);
        push_up(rt);
    } else {
        std::cin >> tree[rt];
    }
}

更新

单点更新, 不需要用到lazy标记

#define lchild rt << 1, l, m
#define rchild rt << 1 | 1, m + 1, r
/* Point update: adds delta to the element at position p and recomputes
   every node on the path back to rt with push_up().

   p     --> position to change
   delta --> amount added to that position
   rt    --> current node, covering positions l..r (the defaults describe
             the root covering 1..N)

   A position outside l..r is ignored. Without that check the descent
   would end in the first or last leaf and change an unrelated element,
   and an empty range (l > r) would never reach a leaf.

   NOTE(review): this routine never calls push_down(), so it presumably
   must not be mixed with the lazy range update on the same tree - confirm. */
void update(int p, int delta, int rt = 1, int l = 1, int r = N) {
    if (p < l || p > r) return;
    if (l == r) {
        tree[rt] += delta;
        return;
    }
    int m = (l + r) >> 1;
    if (p <= m) update(p, delta, lchild);
    else update(p, delta, rchild);
    push_up(rt);
}

成段更新, 需要用到lazy标记来提高时间效率

#define lchild rt << 1, l, m
#define rchild rt << 1 | 1, m + 1, r
/* Range update for the sum tree: adds delta to every position in L..R.

   L, R  --> first and last position of the range to change
   delta --> amount added to each position of the range
   rt    --> current node, covering positions l..r (the defaults describe
             the root covering 1..N)

   A node that lies completely inside L..R is adjusted by
   delta * (r - l + 1) and records delta in lazy[rt] instead of visiting
   its descendants. Otherwise a pending lazy value is pushed down first,
   the overlapping children are updated and the node is recomputed.

   An empty range (L > R) or a range that does not intersect l..r is
   ignored. Without that check a range lying entirely outside the tree
   would recurse without ever reaching a covered node. */
void update(int L, int R, int delta, int rt = 1, int l = 1, int r = N) {
    if (L > R || R < l || r < L) return;
    if (L <= l && r <= R) {
        tree[rt] += delta * (r - l + 1);
        lazy[rt] += delta;
        return;
    }
    if (lazy[rt]) push_down(rt, r - l + 1);
    int m = (l + r) >> 1;
    if (L <= m) update(L, R, delta, lchild);
    if (R > m)  update(L, R, delta, rchild);
    push_up(rt);
}

区间查询

#define lchild rt << 1, l, m
#define rchild rt << 1 | 1, m + 1, r
/* Range query for the sum tree: returns the sum of the positions in L..R.

   L, R --> first and last position of the queried range
   rt   --> current node, covering positions l..r (the defaults describe
            the root covering 1..N)

   A node that lies completely inside L..R contributes tree[rt] directly.
   Otherwise a pending lazy value is pushed down before the overlapping
   children are queried and their results are added up.

   Returns 0 for an empty range (L > R) or a range that does not
   intersect l..r. Without that check a range lying entirely outside the
   tree would recurse without ever reaching a covered node. */
int query(int L, int R, int rt = 1, int l = 1, int r = N) {
    if (L > R || R < l || r < L) return 0;
    if (L <= l && r <= R) return tree[rt];
    if (lazy[rt]) push_down(rt, r - l + 1);
    int m = (l + r) >> 1, ret = 0;
    if (L <= m) ret += query(L, R, lchild);
    if (R > m)  ret += query(L, R, rchild);
    return ret;
}

More

Footnotes:

Author: Shi Shougang

Created: 2017-04-11 Tue 23:16

Emacs 24.3.1 (Org mode 8.2.10)

Validate