1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
/// Sparse (lazily allocated) segment tree over the integer domain [1, n] that
/// stores a count per position, i.e. a multiset of integers in [1, n].
///
/// Nodes are created only along the path of each add, so an empty tree owns
/// no nodes. Every node is owned by exactly one tree: copies are deep, moves
/// transfer the nodes, merge() consumes the other tree and the destructor
/// frees everything. Counts are assumed to stay non-negative (add() accepts a
/// negative val for removal); the k-th and rank-based operations rely on it.
class Seg_Tree {
protected:
    struct Node {
        int sum = 0;         // total count stored in this subtree
        Node* ls = nullptr;  // left child, covers [l, mid]
        Node* rs = nullptr;  // right child, covers [mid + 1, r]

        Node() = default;
        explicit Node(int val) : sum(val) {}
    };

    Node* root;  // nullptr for an empty tree
    int n;       // domain is [1, n]; 0 until build() is called

    // Midpoint of [l, r], written so that l + r cannot overflow.
    static int mid_of(int l, int r) { return l + (r - l) / 2; }

    // Deep copy of the subtree rooted at p.
    static Node* clone(const Node* p) {
        if (p == nullptr) return nullptr;
        Node* c = new Node(p->sum);
        c->ls = clone(p->ls);
        c->rs = clone(p->rs);
        return c;
    }

    // Frees every node of the subtree rooted at p.
    static void destroy(Node* p) {
        if (p == nullptr) return;
        destroy(p->ls);
        destroy(p->rs);
        delete p;
    }

    // Recomputes an internal node's sum from its children (no-op for nullptr).
    void pushup(Node* p) {
        if (!p) return;
        p->sum = (p->ls ? p->ls->sum : 0) + (p->rs ? p->rs->sum : 0);
    }

    // Adds val to the count at position x inside the subtree [l, r] rooted
    // at p, creating missing nodes on the way down.
    void add(Node*& p, int l, int r, int x, int val = 1) {
        if (p == nullptr) p = new Node();
        if (l == r) {
            p->sum += val;
            return;
        }
        int mid = mid_of(l, r);
        if (x <= mid) add(p->ls, l, mid, x, val);
        else add(p->rs, mid + 1, r, x, val);
        pushup(p);
    }

    // Total count of positions in [x, y] within the subtree [l, r].
    int query(const Node* p, int l, int r, int x, int y) const {
        if (p == nullptr || x > r || y < l) return 0;
        if (x <= l && r <= y) return p->sum;
        int mid = mid_of(l, r);
        return query(p->ls, l, mid, x, y) + query(p->rs, mid + 1, r, x, y);
    }

    // Position of the k-th smallest element (1-based, counting duplicates)
    // in the subtree [l, r]; 0 when k is not in [1, p->sum].
    int query(const Node* p, int l, int r, int k) const {
        if (p == nullptr || k < 1 || k > p->sum) return 0;
        if (l == r) return l;
        int mid = mid_of(l, r);
        int left = p->ls ? p->ls->sum : 0;
        if (k <= left) return query(p->ls, l, mid, k);
        return query(p->rs, mid + 1, r, k - left);
    }

    // Merges q into p position-wise (counts are added) and returns the merged
    // root. Destructive: q's nodes are reused or freed and q is set to nullptr.
    Node* merge(Node*& p, Node*& q) {
        if (p == nullptr) {
            Node* moved = q;
            q = nullptr;
            return moved;
        }
        if (q == nullptr) return p;
        p->sum += q->sum;
        p->ls = merge(p->ls, q->ls);
        p->rs = merge(p->rs, q->rs);
        delete q;  // both of its children were moved into p or freed above
        q = nullptr;
        return p;
    }

    // Shared tail of both split helpers for an internal node p whose children
    // have just been split into taken_ls / taken_rs: wraps the extracted parts
    // in a new node and repairs p, freeing it if it was left without children.
    // Returns nullptr when nothing was extracted (p is then unchanged).
    Node* join_split(Node*& p, Node* taken_ls, Node* taken_rs) {
        if (taken_ls == nullptr && taken_rs == nullptr) return nullptr;
        Node* q = new Node();
        q->ls = taken_ls;
        q->rs = taken_rs;
        pushup(q);
        pushup(p);
        if (p->ls == nullptr && p->rs == nullptr) {
            delete p;
            p = nullptr;
        }
        return q;
    }

    // Detaches every position in [x, y] from the subtree [l, r] rooted at p
    // and returns them as a separate subtree over the same [l, r].
    Node* split_idx(Node*& p, int l, int r, int x, int y) {
        if (p == nullptr || x > r || y < l) return nullptr;
        if (x <= l && r <= y) {  // fully covered: hand the subtree over as is
            Node* taken = p;
            p = nullptr;
            return taken;
        }
        int mid = mid_of(l, r);
        Node* taken_ls = split_idx(p->ls, l, mid, x, y);
        Node* taken_rs = split_idx(p->rs, mid + 1, r, x, y);
        return join_split(p, taken_ls, taken_rs);
    }

    // Detaches the elements ranked x..y (1-based, ascending by position,
    // counting duplicates; ranks are relative to this subtree and clipped to
    // [1, p->sum]) and returns them as a separate subtree over the same [l, r].
    Node* split_val(Node*& p, int l, int r, int x, int y) {
        if (p == nullptr) return nullptr;
        if (x < 1) x = 1;
        if (y > p->sum) y = p->sum;
        if (x > y) return nullptr;
        if (x == 1 && y == p->sum) {  // every element of this subtree requested
            Node* taken = p;
            p = nullptr;
            return taken;
        }
        if (l == r) {  // duplicates of one value: move only part of the count
            int cnt = y - x + 1;
            p->sum -= cnt;
            return new Node(cnt);
        }
        int mid = mid_of(l, r);
        int left = p->ls ? p->ls->sum : 0;  // read before the left split mutates it
        Node* taken_ls = split_val(p->ls, l, mid, x, y);
        Node* taken_rs = split_val(p->rs, mid + 1, r, x - left, y - left);
        return join_split(p, taken_ls, taken_rs);
    }

    // Prints "position count" for every existing leaf, in ascending order, to stderr.
    void debug(const Node* p, int l, int r) const {
        if (p == nullptr) return;
        if (l == r) {
            std::cerr << l << " " << p->sum << '\n';
            return;
        }
        int mid = mid_of(l, r);
        debug(p->ls, l, mid);
        debug(p->rs, mid + 1, r);
    }

public:
    Seg_Tree() : root(nullptr), n(0) {}
    Seg_Tree(const Seg_Tree& other) : root(clone(other.root)), n(other.n) {}
    Seg_Tree(Seg_Tree&& other) noexcept : root(other.root), n(other.n) {
        other.root = nullptr;
    }
    Seg_Tree& operator=(const Seg_Tree& other) {
        if (this != &other) {
            Node* fresh = clone(other.root);  // clone first so *this keeps its contents if new throws
            destroy(root);
            root = fresh;
            n = other.n;
        }
        return *this;
    }
    Seg_Tree& operator=(Seg_Tree&& other) noexcept {
        if (this != &other) {
            destroy(root);
            root = other.root;
            n = other.n;
            other.root = nullptr;
        }
        return *this;
    }
    ~Seg_Tree() { destroy(root); }

    /// Number of stored elements (0 for an empty tree).
    int size() const { return root ? root->sum : 0; }

    /// Sets the domain to [1, n_]. Does not clear existing contents.
    void build(int n_) { n = n_; }

    /// Adds val copies of x. Ignored when x is outside [1, n], which includes
    /// every call made before build().
    void add(int x, int val = 1) {
        if (x < 1 || x > n) return;
        add(root, 1, n, x, val);
    }

    /// Number of elements whose value lies in [x, y].
    int query(int x, int y) const { return query(root, 1, n, x, y); }

    /// k-th smallest element (1-based); 0 if k is not in [1, size()].
    int query(int k) const { return query(root, 1, n, k); }

    /// Moves all elements of other into this tree; other is left empty.
    /// Both trees are expected to share the same domain; an unbuilt tree
    /// adopts other's. Merging a tree with itself doubles every count.
    void merge(Seg_Tree& other) {
        if (this == &other) {
            Seg_Tree twin(other);
            root = merge(root, twin.root);
            return;
        }
        if (n == 0) n = other.n;
        root = merge(root, other.root);
    }

    /// Removes the elements whose value lies in [x, y] and returns them as a
    /// new tree over the same domain.
    Seg_Tree split_idx(int x, int y) {
        Seg_Tree other;
        other.n = n;
        other.root = split_idx(root, 1, n, x, y);
        return other;
    }

    /// Removes the x-th through y-th smallest elements (1-based, counting
    /// duplicates; out-of-range ranks are clipped) and returns them as a new
    /// tree over the same domain.
    Seg_Tree split_val(int x, int y) {
        Seg_Tree other;
        other.n = n;
        other.root = split_val(root, 1, n, x, y);
        return other;
    }

    /// Dumps every existing leaf as "value count" to stderr.
    void debug() const { debug(root, 1, n); }
};
|