提问者:小点点

带有std::unique_ptr的红黑树


我正在尝试用std::unique_ptr实现红黑树,但代码无法正常工作。

我的节点定义:

enum class Color {
    Red,
    Black
};

/// Red-black tree node. A node owns its children through `left`/`right`;
/// `parent` is a non-owning back pointer (nullptr for the root).
template <typename T>
struct Node {
    T key;
    Color color {Color::Red};  // newly created nodes start out red
    std::unique_ptr<Node<T>> left;
    std::unique_ptr<Node<T>> right;
    Node<T>* parent {nullptr};

    // The remaining members use their default member initializers, so the
    // constructor no longer lists initializers out of declaration order.
    Node(const T& key) : key {key} {}
};

我选择std::unique_ptr,是因为std::shared_ptr的开销较大,而且父节点独占其左右子节点的所有权。因此,指向父节点的parent自然应该是一个(不拥有所有权的)原始指针。

然而,insert函数的逻辑与这一基本设计产生了冲突:

下面是我的树旋转函数。它们接受std::unique_ptr的右值引用,因为旋转操作实际上会转移节点的所有权。

    // Left-rotates around the node owned by `x`: x's right child `y` takes
    // x's place under x's parent (or as `root`), and x becomes y's left child.
    //
    // NOTE(review): unsafe as written. The if/else below stores `y` into the
    // unique_ptr that owns x's node (root, x->parent->left or ->right), which
    // destroys the node that unique_ptr held:
    //  - if `x` IS that unique_ptr (e.g. LeftRotate(std::move(root))), x's
    //    node is deleted together with its subtrees, `x` then holds y, and
    //    the last two lines make y its own parent and its own left child,
    //    leaving the slot empty;
    //  - if `x` is a second unique_ptr to the same node (as InsertFixup
    //    builds from a raw pointer), `x` dangles, so `x->parent = py` is a
    //    use-after-free and the node is deleted again later.
    // Releasing x's node (x.release()) before overwriting the owning slot
    // avoids both problems.
    void LeftRotate(std::unique_ptr<Node<T>>&& x) {
        auto y = std::move(x->right);
        auto yl = y->left.get();
        x->right = std::move(y->left);
        if (yl) {
            yl->parent = x.get();
        }
        y->parent = x->parent;
        auto py = y.get();
        // Each assignment below deletes the node currently owned by that slot.
        if (!x->parent) {
            root = std::move(y);
        } else if (x == x->parent->left) {
            x->parent->left = std::move(y);
        } else {
            x->parent->right = std::move(y);
        }
        x->parent = py;
        py->left = std::move(x);
    }

    // Right-rotates around the node owned by `x`: x's left child `y` takes
    // x's place under x's parent (or as `root`), and x becomes y's right child.
    //
    // NOTE(review): unsafe as written, for the same reason as LeftRotate.
    // The if/else below stores `y` into the unique_ptr that owns x's node,
    // destroying that node. If `x` is that same unique_ptr, `x` then holds y
    // and the last two lines make y its own parent and right child. If `x` is
    // a second owner built from a raw pointer, `x` dangles and
    // `x->parent = py` is a use-after-free. Release x's node (x.release())
    // before overwriting the owning slot.
    void RightRotate(std::unique_ptr<Node<T>>&& x) {
        auto y = std::move(x->left);
        auto yr = y->right.get();
        x->left = std::move(y->right);
        if (yr) {
            yr->parent = x.get();
        }
        y->parent = x->parent;
        auto py = y.get();
        // Each assignment below deletes the node currently owned by that slot.
        if (!x->parent) {
            root = std::move(y);
        } else if (x == x->parent->left) {
            x->parent->left = std::move(y);
        } else {
            x->parent->right = std::move(y);
        }
        x->parent = py;
        py->right = std::move(x);
    }

下面是我的insert函数:

public:
    /// Allocates a fresh node holding `key` and hands it to the owning
    /// Insert overload.
    void Insert(const T& key) {
        Insert(std::make_unique<Node<T>>(key));
    }

private:
    /// Links the already-allocated node `z` into the tree as in a plain BST
    /// insert (equal keys descend to the right), then passes the unique_ptr
    /// that now owns it to InsertFixup.
    void Insert(std::unique_ptr<Node<T>> z) {
        Node<T>* parentNode = nullptr;
        for (Node<T>* cur = root.get(); cur != nullptr;
             cur = (z->key < cur->key) ? cur->left.get() : cur->right.get()) {
            parentNode = cur;
        }
        z->parent = parentNode;
        if (parentNode == nullptr) {
            root = std::move(z);
            InsertFixup(std::move(root));
            return;
        }
        std::unique_ptr<Node<T>>& slot =
            (z->key < parentNode->key) ? parentNode->left : parentNode->right;
        slot = std::move(z);
        InsertFixup(std::move(slot));
    }

    // Restores the red-black properties after inserting the node owned by
    // `z`. The case structure mirrors CLRS RB-INSERT-FIXUP (z = new node,
    // zp = parent, zpp = grandparent, y = uncle).
    //
    // NOTE(review): broken with unique_ptr ownership, for three reasons:
    //  - `std::unique_ptr<Node<T>>(zp)` / `(zpp)` wrap a node the tree
    //    already owns, giving it a second owner (dangling / double delete).
    //  - `z = ...` assigns to the unique_ptr that `z` refers to (the slot
    //    Insert passed in), deleting the node that slot owned.
    //  - In the recoloring case `zp` moves up two levels but `z` is never
    //    updated, so later `z == zp->right` / `z == zp->left` tests compare
    //    against the originally inserted node instead of the current one.
    // The owning unique_ptr of any node n is reachable in O(1) instead:
    // `root` when n->parent is null, else n->parent->left or ->right.
    void InsertFixup(std::unique_ptr<Node<T>>&& z) {
        auto zp = z->parent;
        while (zp && zp->color == Color::Red) {
            auto zpp = zp->parent;
            if (zp == zpp->left.get()) {
                auto y = zpp->right.get();
                if (y && y->color == Color::Red) {
                    zp->color = Color::Black;
                    y->color = Color::Black;
                    zpp->color = Color::Red;
                    // NOTE(review): the violation moves up to zpp, but `z`
                    // still names the originally inserted node.
                    zp = zpp->parent;
                } else {
                    if (z == zp->right) {
                        // NOTE(review): creates a second owner of zp and
                        // overwrites the slot `z` refers to.
                        z = std::unique_ptr<Node<T>>(zp);
                        auto pz = z.get();
                        LeftRotate(std::move(z));
                        zp = pz->parent;
                        zpp = zp->parent;
                    }
                    zp->color = Color::Black;
                    zpp->color = Color::Red;
                    // zpp is already owned by zpp->parent->left/right (or by
                    // root); wrapping it again creates a second owner.
                    auto pzpp = std::unique_ptr<Node<T>>(zpp); // error
                    RightRotate(std::move(pzpp)); // error
                }
            } else {
                auto y = zpp->left.get();
                if (y && y->color == Color::Red) {
                    zp->color = Color::Black;
                    y->color = Color::Black;
                    zpp->color = Color::Red;
                    // NOTE(review): same stale-`z` problem as the mirror case.
                    zp = zpp->parent;
                } else {
                    if (z == zp->left) {
                        // NOTE(review): second owner of zp, and overwrites
                        // the slot `z` refers to.
                        z = std::unique_ptr<Node<T>>(zp);
                        auto pz = z.get();
                        RightRotate(std::move(z));
                        zp = pz->parent;
                        zpp = zp->parent;
                    }
                    zp->color = Color::Black;
                    zpp->color = Color::Red;
                    // Second owner of zpp, as in the mirror case above.
                    auto pzpp = std::unique_ptr<Node<T>>(zpp); // error
                    LeftRotate(std::move(pzpp)); // error
                }
            }
        }
        root->color = Color::Black;
    }

InsertFixup中的以下行存在错误:

auto pzpp = std::unique_ptr<Node<T>>(zpp); // error
LeftRotate(std::move(pzpp)); // error

我想做的是以节点z的祖父节点为轴旋转树。

但问题在于,我无法获得拥有祖父节点的那个std::unique_ptr(它需要作为参数传给LeftRotate/RightRotate函数),因为在我的节点实现中,父链接只是一个原始指针。当然,我可以从根节点向下查找来得到它,但这会破坏红黑树插入操作的对数时间复杂度,使这种数据结构失去意义。

我是否应该改用std::shared_ptr?有没有办法用std::unique_ptr来实现红黑树?


共1个答案

匿名用户

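我现在已经正确实现了insert,它在随机顺序插入测试中运行良好。关键在于:拥有某个节点的std::unique_ptr可以在O(1)时间内找到。若该节点没有父节点,它就是root;否则就是其父节点的left或right。因此在InsertFixup中,可以通过祖父节点的父节点(zppp->left、zppp->right或root)直接取得拥有祖父节点的unique_ptr并传给旋转函数,而不必从原始指针构造新的unique_ptr。旋转函数内部则先release节点,再改写所属的槽位。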

当前可以工作的代码:

#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Global Mersenne Twister engine, seeded once from std::random_device.
std::mt19937 gen{std::random_device{}()};

enum class Color {
    Red,
    Black
};

/// Red-black tree node. A node owns its children through `left`/`right`;
/// `parent` is a non-owning back pointer (nullptr for the root).
template <typename T>
struct Node {
    T key;
    Color color {Color::Red};  // newly created nodes start out red
    std::unique_ptr<Node<T>> left;
    std::unique_ptr<Node<T>> right;
    Node<T>* parent {nullptr};

    // The remaining members use their default member initializers, so the
    // constructor no longer lists initializers out of declaration order.
    Node(const T& key) : key {key} {}
};

/// Red-black tree built on unique ownership: `root` owns the top node and
/// every other node is owned by its parent's `left` or `right`.
/// Equal keys are inserted into the right subtree.
template <typename T>
struct RBTree {
public:
    std::unique_ptr<Node<T>> root;

private:
    /// Returns the unique_ptr that owns `n`, in O(1): `root` when `n` has no
    /// parent, otherwise whichever of n->parent->left / ->right points at it.
    /// This lets the fix-up rotate around any node without shared_ptr and
    /// without searching down from the root.
    std::unique_ptr<Node<T>>& OwnerOf(Node<T>* n) {
        Node<T>* p = n->parent;
        if (!p) {
            return root;
        }
        return p->left.get() == n ? p->left : p->right;
    }

    /// Left rotation around the node owned by `x`.
    /// @param x the unique_ptr inside this tree that owns the pivot (`root`
    ///          or a parent's `left`/`right`). The pivot must have a right
    ///          child. On return the same unique_ptr owns the pivot's former
    ///          right child, and the pivot is that node's left child.
    void LeftRotate(std::unique_ptr<Node<T>>&& x) {
        std::unique_ptr<Node<T>>& slot = x;
        assert(slot && slot->right);
        assert(&slot == &OwnerOf(slot.get()));
        // Take the pivot out of its slot first, so storing the new subtree
        // root into the slot cannot destroy the pivot.
        std::unique_ptr<Node<T>> pivot = std::move(slot);
        std::unique_ptr<Node<T>> y = std::move(pivot->right);
        pivot->right = std::move(y->left);
        if (pivot->right) {
            pivot->right->parent = pivot.get();
        }
        y->parent = pivot->parent;
        pivot->parent = y.get();
        y->left = std::move(pivot);
        slot = std::move(y);
    }

    /// Right rotation around the node owned by `x`; mirror of LeftRotate.
    /// The pivot must have a left child. On return the same unique_ptr owns
    /// the pivot's former left child, and the pivot is that node's right child.
    void RightRotate(std::unique_ptr<Node<T>>&& x) {
        std::unique_ptr<Node<T>>& slot = x;
        assert(slot && slot->left);
        assert(&slot == &OwnerOf(slot.get()));
        // Detach first; see LeftRotate.
        std::unique_ptr<Node<T>> pivot = std::move(slot);
        std::unique_ptr<Node<T>> y = std::move(pivot->left);
        pivot->left = std::move(y->right);
        if (pivot->left) {
            pivot->left->parent = pivot.get();
        }
        y->parent = pivot->parent;
        pivot->parent = y.get();
        y->right = std::move(pivot);
        slot = std::move(y);
    }

public:
    /// Inserts `key` (duplicates allowed) and rebalances the tree.
    void Insert(const T& key) {
        auto z = std::make_unique<Node<T>>(key);
        Insert(std::move(z));
    }

private:
    /// Links `z` in as a BST leaf (equal keys go right), then fixes up colors
    /// starting from the unique_ptr that now owns it.
    void Insert(std::unique_ptr<Node<T>> z) {
        Node<T>* y = nullptr;
        Node<T>* x = root.get();
        while (x) {
            y = x;
            if (z->key < x->key) {
                x = x->left.get();
            } else {
                x = x->right.get();
            }
        }
        z->parent = y;
        if (!y) {
            root = std::move(z);
            InsertFixup(std::move(root));
        } else if (z->key < y->key) {
            y->left = std::move(z);
            InsertFixup(std::move(y->left));
        } else {
            y->right = std::move(z);
            InsertFixup(std::move(y->right));
        }
    }

    /// Restores the red-black properties after the red node owned by `z`
    /// was linked into the tree (CLRS RB-INSERT-FIXUP cases 1-3).
    void InsertFixup(std::unique_ptr<Node<T>>&& z) {
        // Track the current node by raw pointer: `z` names a slot, and both
        // recoloring (which moves the violation two levels up) and rotations
        // make that slot refer to something else.
        Node<T>* node = z.get();
        while (node->parent && node->parent->color == Color::Red) {
            Node<T>* zp = node->parent;
            // zp is red and the root is always black here, so zp has a parent.
            Node<T>* zpp = zp->parent;
            if (zp == zpp->left.get()) {
                Node<T>* uncle = zpp->right.get();
                if (uncle && uncle->color == Color::Red) {
                    // Case 1: recolor and continue from the grandparent.
                    zp->color = Color::Black;
                    uncle->color = Color::Black;
                    zpp->color = Color::Red;
                    node = zpp;
                } else {
                    if (node == zp->right.get()) {
                        // Case 2: turn the left-right shape into left-left.
                        LeftRotate(std::move(zpp->left));
                        node = zp;
                        zp = node->parent;
                    }
                    // Case 3: zp becomes the black subtree root, which ends
                    // the loop.
                    zp->color = Color::Black;
                    zpp->color = Color::Red;
                    RightRotate(std::move(OwnerOf(zpp)));
                }
            } else {
                Node<T>* uncle = zpp->left.get();
                if (uncle && uncle->color == Color::Red) {
                    // Case 1 (mirrored).
                    zp->color = Color::Black;
                    uncle->color = Color::Black;
                    zpp->color = Color::Red;
                    node = zpp;
                } else {
                    if (node == zp->left.get()) {
                        // Case 2 (mirrored): right-left -> right-right.
                        RightRotate(std::move(zpp->right));
                        node = zp;
                        zp = node->parent;
                    }
                    // Case 3 (mirrored).
                    zp->color = Color::Black;
                    zpp->color = Color::Red;
                    LeftRotate(std::move(OwnerOf(zpp)));
                }
            }
        }
        root->color = Color::Black;
    }
};

template <typename T>
std::ostream& operator<<(std::ostream& os, Node<T>* node) {
    if (node) {
        os << node->left.get();
        os << node->key;
        if (node->color == Color::Black) {
            os << "● ";
        } else {
            os << "○ ";
        }
        os << node->right.get();
    }
    return os;
}

/// Streams the whole tree by delegating to the Node pointer printer
/// (an empty tree has a null root and so prints nothing).
template <typename T>
std::ostream& operator<<(std::ostream& os, const RBTree<T>& tree) {
    return os << tree.root.get();
}

int main() {
    // Insert the keys 1..SIZE in a shuffled order (driven by the global
    // `gen`) and print the resulting tree in order.
    constexpr size_t SIZE = 30;
    std::vector<int> v (SIZE);
    std::iota(v.begin(), v.end(), 1);
    std::shuffle(v.begin(), v.end(), gen);  // declared in <algorithm>
    RBTree<int> rbtree;
    for (const int n : v) { // random order insert test
        rbtree.Insert(n);
    }
    // Terminate the output line so the printed tree is a complete line.
    std::cout << rbtree << '\n';
}