2023年7月4日
共享指针循环引用使用弱指针解决
#include <iostream>#include <memory>#include <string>/// @brief 手写共享指针/// @tparam T 指针内存类型template <class T>class shared_ptr_t {private: T* ptr_; int* count_;
private: /// @brief 释放内存 void release() { if (ptr_ && (--(*count_)) == 0) { delete ptr_; delete count_; } } /// @brief 根据传进来的指针,增加引用计数 /// @param sp 指针 void add_count(shared_ptr_t<T>& sp) { ptr_ = sp.ptr_; count_ = sp.count_; if (ptr_) { ++(*count_); } }
public: /// @brief 构造函数 /// @param ptr 指针 shared_ptr_t(T* ptr = nullptr) : ptr_(ptr) { if (ptr_) { count_ = new int(1); } } /// @brief 拷贝构造函数 /// @param sp 指针 shared_ptr_t(shared_ptr_t& sp) { add_count(sp); } /// @brief 赋值函数 /// @param sp 指针 /// @return shared_ptr_t<T>& shared_ptr_t<T>& operator=(shared_ptr_t<T>& sp) { if (this != &sp) { release(); add_count(sp); } return *this; } /// @brief 析构函数 ~shared_ptr_t() { printf("destruct shared_ptr count: %d, 若解决循环引用的问题最后一行应该显示为1\n", *count_); release(); } /// @brief 重载*和->操作符 /// @return T& T& operator*() { return *ptr_; } T* operator->() { return ptr_; } T* get() { return ptr_; } int use_count() { return *count_; } int* count() { return count_; }};
/// @brief 手写弱指针/// @tparam T 指针内存类型template <class T>class weak_ptr_t {private: int* count_; T* ptr_;
public: /// @brief 弱指针构造函数 /// @param ptr 传入的指针 weak_ptr_t(T* ptr = nullptr) : count_(nullptr), ptr_(nullptr) {} /// @brief 拷贝构造函数,需要搭配手写shared_ptr_t使用,获取共享指针的count指针和ptr指针 /// @param sp 传入的指针 weak_ptr_t(shared_ptr_t<T>& sp) : count_(sp.count()), ptr_(sp.get()) {} /// @brief 赋值函数 /// @param sp 传入的指针 /// @return weak_ptr_t<T>& weak_ptr_t<T>& operator=(shared_ptr_t<T>& sp) { count_ = sp.count(); ptr_ = sp.get(); return *this; } /// @brief lock函数 /// @return 若引用计数大于0,返回指针,否则返回空指针 T* lock() { if (count_ && *count_ > 0) { return ptr_; } return nullptr; } /// @brief 获取引用计数 /// @return int 引用计数 int use_count() { return *count_; } /// @brief 获取引用计数指针 /// @return int* 引用计数指针 int* count() { return count_; }};
class ListNode {public: int val;#if false shared_ptr_t<ListNode> next; shared_ptr_t<ListNode> previous;#else weak_ptr_t<ListNode> next; weak_ptr_t<ListNode> previous;#endif
ListNode(int x) : val(x), next(nullptr), previous(nullptr) {}};
int main() { // 使用weak_ptr解决shared指针循环引用的问题 shared_ptr_t<ListNode> p3(new ListNode(800)); shared_ptr_t<ListNode> p4(new ListNode(800)); p3->next = p4; p4->previous = p3; printf("p3 use count: %d\n", p3.use_count()); printf("p4 use count: %d\n", p4.use_count());}