我正在用std::unique_ptr
玩红黑树,但它不起作用。
我的节点定义:
enum class Color {
Red,
Black
};
template <typename T>
struct Node {
T key;
Color color;
std::unique_ptr<Node<T>> left;
std::unique_ptr<Node<T>> right;
Node<T>* parent;
Node(const T& key) : key {key}, parent {nullptr}, color {Color::Red} {}
};
我选择std::unique_ptr
是因为std::shared_ptr
开销很大,并且父级拥有它的左右子级。 简单地说,parent
应该是一个原始指针。
然而,insert
函数背后的逻辑打破了我的基本设计:
下面是我的树旋转函数。 它们接受std::unique_ptr
的rvalue引用,因为它实际上转移了所有权。
void LeftRotate(std::unique_ptr<Node<T>>&& x) {
auto y = std::move(x->right);
auto yl = y->left.get();
x->right = std::move(y->left);
if (yl) {
yl->parent = x.get();
}
y->parent = x->parent;
auto py = y.get();
if (!x->parent) {
root = std::move(y);
} else if (x == x->parent->left) {
x->parent->left = std::move(y);
} else {
x->parent->right = std::move(y);
}
x->parent = py;
py->left = std::move(x);
}
void RightRotate(std::unique_ptr<Node<T>>&& x) {
auto y = std::move(x->left);
auto yr = y->right.get();
x->left = std::move(y->right);
if (yr) {
yr->parent = x.get();
}
y->parent = x->parent;
auto py = y.get();
if (!x->parent) {
root = std::move(y);
} else if (x == x->parent->left) {
x->parent->left = std::move(y);
} else {
x->parent->right = std::move(y);
}
x->parent = py;
py->right = std::move(x);
}
下面是我的insert
函数:
public:
void Insert(const T& key) {
auto z = std::make_unique<Node<T>>(key);
Insert(std::move(z));
}
private:
void Insert(std::unique_ptr<Node<T>> z) {
Node<T>* y = nullptr;
Node<T>* x = root.get();
while (x) {
y = x;
if (z->key < x->key) {
x = x->left.get();
} else {
x = x->right.get();
}
}
z->parent = y;
if (!y) {
root = std::move(z);
InsertFixup(std::move(root));
} else if (z->key < y->key) {
y->left = std::move(z);
InsertFixup(std::move(y->left));
} else {
y->right = std::move(z);
InsertFixup(std::move(y->right));
}
}
void InsertFixup(std::unique_ptr<Node<T>>&& z) {
auto zp = z->parent;
while (zp && zp->color == Color::Red) {
auto zpp = zp->parent;
if (zp == zpp->left.get()) {
auto y = zpp->right.get();
if (y && y->color == Color::Red) {
zp->color = Color::Black;
y->color = Color::Black;
zpp->color = Color::Red;
zp = zpp->parent;
} else {
if (z == zp->right) {
z = std::unique_ptr<Node<T>>(zp);
auto pz = z.get();
LeftRotate(std::move(z));
zp = pz->parent;
zpp = zp->parent;
}
zp->color = Color::Black;
zpp->color = Color::Red;
auto pzpp = std::unique_ptr<Node<T>>(zpp); // error
RightRotate(std::move(pzpp)); // error
}
} else {
auto y = zpp->left.get();
if (y && y->color == Color::Red) {
zp->color = Color::Black;
y->color = Color::Black;
zpp->color = Color::Red;
zp = zpp->parent;
} else {
if (z == zp->left) {
z = std::unique_ptr<Node<T>>(zp);
auto pz = z.get();
RightRotate(std::move(z));
zp = pz->parent;
zpp = zp->parent;
}
zp->color = Color::Black;
zpp->color = Color::Red;
auto pzpp = std::unique_ptr<Node<T>>(zpp); // error
LeftRotate(std::move(pzpp)); // error
}
}
}
root->color = Color::Black;
}
InsertFixup
中的以下行存在错误:
auto pzpp = std::unique_ptr<Node<T>>(zpp); // error
LeftRotate(std::move(pzpp)); // error
我要做的是围绕节点z
的祖母旋转树。
但是,问题是不可能获得拥有祖母节点(需要传递给leftrotate
函数)的std::unique_ptr
,因为我的节点实现的父链接给出了一个原始指针。 当然,我可以从根追踪下来get,但是这样做会打破RB-Tree的插入操作的对数时间复杂度,使其形同虚设。
是否应该改用std::shared_ptr
? 有没有方法用std::unique_ptr
实现来实现RB-tree?
现在我对insert
进行了正确的实现。 它与随机顺序插入测试一起工作得很好。
当前工作:
#include <cassert>
#include <iostream>
#include <memory>
#include <utility>
#include <numeric>
#include <vector>
#include <random>
std::mt19937 gen(std::random_device{}());
enum class Color {
Red,
Black
};
template <typename T>
struct Node {
T key;
Color color;
std::unique_ptr<Node<T>> left;
std::unique_ptr<Node<T>> right;
Node<T>* parent;
Node(const T& key) : key {key}, parent {nullptr}, color {Color::Red} {}
};
template <typename T>
struct RBTree {
public:
std::unique_ptr<Node<T>> root;
private:
void LeftRotate(std::unique_ptr<Node<T>>&& x) {
auto y = std::move(x->right);
x->right = std::move(y->left);
if (x->right) {
x->right->parent = x.get();
}
y->parent = x->parent;
auto xp = x->parent;
if (!xp) {
auto px = x.release();
root = std::move(y);
root->left = std::unique_ptr<Node<T>>(px);
root->left->parent = root.get();
} else if (x == xp->left) {
auto px = x.release();
xp->left = std::move(y);
xp->left->left = std::unique_ptr<Node<T>>(px);
xp->left->left->parent = xp->left.get();
} else {
auto px = x.release();
xp->right = std::move(y);
xp->right->left = std::unique_ptr<Node<T>>(px);
xp->right->left->parent = xp->right.get();
}
}
void RightRotate(std::unique_ptr<Node<T>>&& x) {
auto y = std::move(x->left);
x->left = std::move(y->right);
if (x->left) {
x->left->parent = x.get();
}
y->parent = x->parent;
auto xp = x->parent;
if (!xp) {
auto px = x.release();
root = std::move(y);
root->right = std::unique_ptr<Node<T>>(px);
root->right->parent = root.get();
} else if (x == xp->left) {
auto px = x.release();
xp->left = std::move(y);
xp->left->right = std::unique_ptr<Node<T>>(px);
xp->left->right->parent = xp->left.get();
} else {
auto px = x.release();
xp->right = std::move(y);
xp->right->right = std::unique_ptr<Node<T>>(px);
xp->right->right->parent = xp->right.get();
}
}
public:
void Insert(const T& key) {
auto z = std::make_unique<Node<T>>(key);
Insert(std::move(z));
}
private:
void Insert(std::unique_ptr<Node<T>> z) {
Node<T>* y = nullptr;
Node<T>* x = root.get();
while (x) {
y = x;
if (z->key < x->key) {
x = x->left.get();
} else {
x = x->right.get();
}
}
z->parent = y;
if (!y) {
root = std::move(z);
InsertFixup(std::move(root));
} else if (z->key < y->key) {
y->left = std::move(z);
InsertFixup(std::move(y->left));
} else {
y->right = std::move(z);
InsertFixup(std::move(y->right));
}
}
void InsertFixup(std::unique_ptr<Node<T>>&& z) {
auto zp = z->parent;
while (zp && zp->color == Color::Red) {
auto zpp = zp->parent;
if (zp == zpp->left.get()) {
auto y = zpp->right.get();
if (y && y->color == Color::Red) {
zp->color = Color::Black;
y->color = Color::Black;
zpp->color = Color::Red;
zp = zpp->parent;
} else {
if (z == zp->right) {
LeftRotate(std::move(zpp->left));
zp = zpp->left.get();
}
zp->color = Color::Black;
zpp->color = Color::Red;
auto zppp = zpp->parent;
if (!zppp) {
RightRotate(std::move(root));
} else if (zpp == zppp->left.get()) {
RightRotate(std::move(zppp->left));
} else {
RightRotate(std::move(zppp->right));
}
}
} else {
auto y = zpp->left.get();
if (y && y->color == Color::Red) {
zp->color = Color::Black;
y->color = Color::Black;
zpp->color = Color::Red;
zp = zpp->parent;
} else {
if (z == zp->left) {
RightRotate(std::move(zpp->right));
zp = zpp->right.get();
}
zp->color = Color::Black;
zpp->color = Color::Red;
auto zppp = zpp->parent;
if (!zppp) {
LeftRotate(std::move(root));
} else if (zpp == zppp->left.get()) {
LeftRotate(std::move(zppp->left));
} else {
LeftRotate(std::move(zppp->right));
}
}
}
}
root->color = Color::Black;
}
};
template <typename T>
std::ostream& operator<<(std::ostream& os, Node<T>* node) {
if (node) {
os << node->left.get();
os << node->key;
if (node->color == Color::Black) {
os << "● ";
} else {
os << "○ ";
}
os << node->right.get();
}
return os;
}
template <typename T>
std::ostream& operator<<(std::ostream& os, const RBTree<T>& tree) {
os << tree.root.get();
return os;
}
int main() {
constexpr size_t SIZE = 30;
std::vector<int> v (SIZE);
std::iota(v.begin(), v.end(), 1);
std::shuffle(v.begin(), v.end(), gen);
RBTree<int> rbtree;
for (auto n : v) { // random order insert test
rbtree.Insert(n);
}
std::cout << rbtree;
}