2023年7月4日

手写智能指针:从 shared_ptr 到 weak_ptr 再到循环引用


我的立场: 手写智能指针是我见过最好的 C++ 入门练习之一。不是因为你会在生产中用到它,而是因为写完之后你会真正理解 RAII 和引用计数到底在干什么。

我的偏见: 我觉得很多人用 shared_ptr 用得太随意了,到处 shared 其实是设计没想清楚。但这不影响你需要理解它。


先写 shared_ptr

核心就两个成员:一个裸指针,一个引用计数。拷贝的时候计数加一,析构的时候计数减一,减到零就 delete。

就这么简单。

#include <iostream>
#include <string>
template <class T>
class shared_ptr_t {
private:
T* ptr_;
int* count_;
void release() {
if (ptr_ && (--(*count_)) == 0) {
delete ptr_;
delete count_;
}
}
void add_count(shared_ptr_t<T>& sp) {
ptr_ = sp.ptr_;
count_ = sp.count_;
if (ptr_) {
++(*count_);
}
}
public:
shared_ptr_t(T* ptr = nullptr) : ptr_(ptr) {
if (ptr_) {
count_ = new int(1);
}
}
shared_ptr_t(shared_ptr_t& sp) { add_count(sp); }
shared_ptr_t<T>& operator=(shared_ptr_t<T>& sp) {
if (this != &sp) {
release();
add_count(sp);
}
return *this;
}
~shared_ptr_t() {
printf("~shared_ptr_t count: %d\n", *count_);
release();
}
T& operator*() { return *ptr_; }
T* operator->() { return ptr_; }
T* get() { return ptr_; }
int use_count() { return *count_; }
int* count() { return count_; }
};

试一下:

int main() {
shared_ptr_t<std::string> sp1{new std::string("Hello")};
shared_ptr_t<std::string> sp2{sp1};
shared_ptr_t<std::string> sp3{new std::string("World")};
printf("sp1 use count: %d\n", sp1.use_count()); // 2
sp3 = sp2;
printf("sp1 use count: %d\n", sp1.use_count()); // 3
return 0;
}

析构的时候会看到计数从 3 递减到 0,最后一个析构的负责 delete。没什么魔法。

再写 weak_ptr

weak_ptr 存在的唯一理由:打破循环引用。

它不增加引用计数,只是”观察”一个 shared_ptr 管理的对象。想用的时候调 lock(),对象还活着就返回指针,死了就返回 nullptr。

template <class T>
class weak_ptr_t {
private:
int* count_;
T* ptr_;
public:
weak_ptr_t(T* ptr = nullptr) : count_(nullptr), ptr_(nullptr) {}
weak_ptr_t(shared_ptr_t<T>& sp) : count_(sp.count()), ptr_(sp.get()) {}
weak_ptr_t<T>& operator=(shared_ptr_t<T>& sp) {
count_ = sp.count();
ptr_ = sp.get();
return *this;
}
T* lock() {
if (count_ && *count_ > 0) {
return ptr_;
}
return nullptr;
}
int use_count() { return *count_; }
int* count() { return count_; }
};

注意这个实现是简化版的。标准库的 weak_ptr 还有单独的 weak count,这里省略了。够说明问题就行。

循环引用:shared_ptr 的经典翻车现场

两个节点互相持有对方的 shared_ptr,引用计数永远不会归零,内存泄漏。

class ListNode {
public:
int val;
// 把 shared_ptr_t 换成 weak_ptr_t 就能解决
#if false
shared_ptr_t<ListNode> next;
shared_ptr_t<ListNode> previous;
#else
weak_ptr_t<ListNode> next;
weak_ptr_t<ListNode> previous;
#endif
ListNode(int x) : val(x), next(nullptr), previous(nullptr) {}
};
int main() {
shared_ptr_t<ListNode> p1(new ListNode(800));
shared_ptr_t<ListNode> p2(new ListNode(800));
p1->next = p2;
p2->previous = p1;
printf("p1 use count: %d\n", p1.use_count());
printf("p2 use count: %d\n", p2.use_count());
}

用 shared_ptr_t 的时候,析构时 count 是 2,永远不会到 0。换成 weak_ptr_t,count 就是 1,析构正常释放。

#if false 改成 #if true 自己跑一下就知道区别了。


写完这三个东西,你对 C++ 内存管理的理解会上一个台阶。不是因为代码多复杂,而是因为你亲手处理了”谁拥有这块内存”这个问题。

生产中当然用标准库的。但标准库的实现你看不懂的时候,回来看看这个就行。