线段树

Aki 发布于 2023-01-08 267 次阅读


介绍、

线段树是一种高级数据结构,也是一种树结构,准确的说是二叉树。它能够高效的处理区间修改查询等问题。因此学习线段树,我们就是在学习一种全新的高级数据结构,学习它如何组织数据,然后高效的进行数据查询,修改等操作。

首先线段树是一棵二叉树, 平常我们所指的线段树都是指一维线段树。 故名思义, 线段树能解决的是线段上的问题, 这个线段也可指区间

线段树的一个节点可以存好多东西,如一个区间,区间和,区间最大值,区间最小值等等。

如下结构体设计。线段树常用来查询某一个区间的和,时间复杂度是O(logn),修改某一个区间的和也是O(logn),单点修改也是O(logn),是一种很快的数据结构,用在解题中。

代码实现、

#include<iostream>
using namespace std;
#include<cassert>
#include<algorithm>
#include<queue>
#include<numeric>


//线段树节点结构体
class Segment_Tree_Node
{
public:

        //构造函数
	Segment_Tree_Node(int begin,int end,Segment_Tree_Node* parent):begin(begin),end(end),parent(parent){}
	~Segment_Tree_Node()noexcept{}

	int begin = 0;  //区间开始
	int end = 0;    //区间结束
	int sum = 0;    //区间和
	Segment_Tree_Node* left = nullptr;   //左子树
	Segment_Tree_Node* right = nullptr;  //右子树
	Segment_Tree_Node* parent = nullptr; //父亲节点
};


//线段树结构体
class Segment_Tree
{
public:

	Segment_Tree():array(nullptr),_root(nullptr), array_size(0){}

	~Segment_Tree()noexcept
	{
		clear();
	}

	void clear()
	{
		_destroy_(_root);
		_root  = nullptr;
		if (array)
		{
			delete[]array;
		}
		array = nullptr;
		array_size = 0;
	}

        
        //使用一个数组来初始化线段树
	void initialize(int* arr,int size)
	{
                if (array)
		{
			delete[]array;
		}
		array = arr;
		array_size = size;
		_root = _build(0, size - 1,nullptr);
	}

        //BFS
	void BFS()
	{
		if (_root == nullptr)
		{
			return;
		}
		queue<Segment_Tree_Node*> q{};
		q.push(_root);
		Segment_Tree_Node* tmp = nullptr;
		while (!q.empty())
		{
			tmp = q.front();
			q.pop();
			cout << tmp->sum << "   ";
			if (tmp->left)
			{
				q.push(tmp->left);
			}
			if (tmp->right)
			{
				q.push(tmp->right);
			}
		}
		cout << endl;
	}

        //数组区间求和,范围是[begin,end]
	int sum(int begin, int end)
	{
                assert(begin >= 0 && begin < array_size);
		assert(end >= 0 && end < array_size);
		assert(begin <= end);
		return _sum_(begin, end, _root);
	}

        //单点更新,位置是index,更新后的值是val
	void update(int index, int val)
	{
            	assert(index >= 0 && index < array_size);
		int cha = val - array[index];
		array[index] = val;
		Segment_Tree_Node* pos = _find(_root,index);
		if (pos == nullptr)
		{
			return;
		}
		pos->sum = val;
		while (pos->parent)
		{
			pos = pos->parent;
			pos->sum += cha;
		}
	}

	
protected:


	int _sum_(int begin, int end, Segment_Tree_Node* _p)
	{
		if (begin <= _p->begin && end >= _p->end)
		{
			return _p->sum;
		}
		int mid = (_p->begin + _p->end) / 2;
		int sum = 0;
		if (begin <= mid)
		{
			sum += _sum_(begin, end, _p->left);
		}
		if (end > mid)
		{
			sum += _sum_(begin, end, _p->right);
		}
		return sum;
	}


	Segment_Tree_Node* _find(Segment_Tree_Node* _p,int index)
	{
		if (_p->begin == index && _p->end == index )
		{
			return _p;
		}
		int mid = (_p->begin + _p->end) / 2;
		if (index <= mid)
		{
			return _find(_p->left, index);
		}
		else if (index > mid)
		{
			return _find(_p->right, index);
		}
		return nullptr;
	}


	void _destroy_(Segment_Tree_Node* _p)
	{
		if (_p == nullptr)
		{
			return;
		}
		_destroy_(_p->left);
		_destroy_(_p->right);
		delete _p;
	}


        //线段树构造,递归方式
	Segment_Tree_Node* _build(int begin, int end,Segment_Tree_Node*parent)
	{
		if (begin > end)
		{
			return nullptr;
		}
		Segment_Tree_Node* _p = new Segment_Tree_Node{begin,end,parent};
		_p->sum = accumulate(array + begin, array + end + 1, 0);
		if (begin == end)
		{
			return _p;
		}
		int mid = (end + begin) / 2;
		_p->left = _build(begin, mid, _p);
		_p->right = _build(mid + 1, end, _p);
		return _p;
	}


private:
	int* array;
	int array_size;
	Segment_Tree_Node* _root;
};

int main()
{

	int *array = new int[]{ 0,1,3,4,5,7,9,10 };
	Segment_Tree t{};
	t.initialize(array,8);

	t.BFS();

	t.update(7, 25);
	t.update(0, 5);

	t.BFS();

	cout << t.sum(0, 7) << endl;
	cout << t.sum(2, 5) << endl;

	

	return 0;
}