C++ 手撸一个智能指针

       最近,我写了一个简单的shared_ptr,在这里分享一波。
       首先定义一个负责管理引用计数的类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
/// Thread-safe reference counter shared by every SharedPtr that owns the same
/// object. A fresh counter starts at 1: it is created together with its first
/// owner.
class SharedCount
{
public:
    SharedCount() : count_(1) {}
    /// Registers one more owner.
    void add() { ++count_; }
    /// Unregisters one owner and returns the number of owners left.
    /// Test this return value instead of calling get() afterwards: only the
    /// result of the atomic decrement itself tells exactly one releasing
    /// thread that it dropped the last reference.
    int minus() { return --count_; }
    /// Snapshot of the current owner count.
    int get() const { return count_; }
private:
    std::atomic<int> count_;
};

       然后就是SharedPtr类,首先在构造函数中创建SharedCount的对象:

1
2
3
4
5
6
7
8
9
10
11
/// Reference-counted owning pointer (first skeleton: constructors only).
template <typename T>
class SharedPtr
{
public:
    /// Empty pointer; it still gets its own counter, which starts at 1.
    SharedPtr() : SharedPtr(nullptr) {}

    /// Adopts `ptr` with a brand-new counter (count == 1).
    SharedPtr(T* ptr) : ptr_(ptr), ref_count_(new SharedCount) {}

private:
    T* ptr_;                   // object being shared
    SharedCount* ref_count_;   // counter shared with every copy
};

       通过构造函数创建出来的SharedPtr引用计数肯定是1,那析构函数怎么实现?无非就是将引用计数减1,如果引用计数最终减到0,则释放所管理的对象以及引用计数对象本身:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
...
/// Every owner gives up its reference when it is destroyed.
~SharedPtr() { clean(); }
private:
/// Drops this owner's reference; whoever brings the count to zero deletes
/// both the managed object and the counter.
/// NOTE(review): minus() followed by a separate get() is not an atomic
/// "decrement and test". Two threads releasing concurrently could both read 0
/// and delete twice. Having minus() return the decremented value and testing
/// that result closes the gap.
void clean()
{
    if (ref_count_ == nullptr)
    {
        return;  // moved-from: nothing is owned
    }
    ref_count_->minus();
    if (ref_count_->get() != 0)
    {
        return;  // other owners remain
    }
    if (ptr_ != nullptr)
    {
        delete ptr_;
    }
    delete ref_count_;
}
...

       然后就是智能指针的关键部分,即在拷贝构造和拷贝赋值的时候将引用计数+1:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
...
SharedPtr(const SharedPtr& p)
{
this->ptr_ = p.ptr_;
this->ref_count_ = p.ref_count_;
ref_count_->add();
}
SharedPtr& operator=(const SharedPtr& p)
{
clean();
this->ptr_ = p.ptr_;
this->ref_count_ = p.ref_count_;
ref_count_->add();
return *this;
}
...

       处理了拷贝语义,还需要处理移动语义,即实现移动构造和移动赋值函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
...
/// Move constructor: takes over p's object and counter. The count does not
/// change, and p is left empty (no object, no counter).
SharedPtr(SharedPtr&& p) noexcept : ptr_(p.ptr_), ref_count_(p.ref_count_)
{
    p.ptr_ = nullptr;
    p.ref_count_ = nullptr;
}
/// Move assignment: releases the current object and takes over p's.
SharedPtr& operator=(SharedPtr&& p) noexcept
{
    // Detach p first, then release our old reference. When p is *this the
    // detach empties *this, clean() has nothing to release, and the state is
    // restored below, so a self-move no longer destroys the object.
    T* new_ptr = p.ptr_;
    SharedCount* new_count = p.ref_count_;
    p.ptr_ = nullptr;
    p.ref_count_ = nullptr;
    clean();
    this->ptr_ = new_ptr;
    this->ref_count_ = new_count;
    return *this;
}
...

       在移动语义中,引用计数保持不变,同时清空原参数中的指针。
       关于共享指针,到这里基本逻辑都已经实现完成,但还需要补充获取裸指针、获取引用计数等接口:

1
2
3
4
5
6
7
8
9
10
11
12
13
...
int use_count() { return ref_count_->get(); }
T* get() const { return ptr_; }
T* operator->() const { return ptr_; }
T& operator*() const { return *ptr_; }
operator bool() const { return ptr_; }
...

       这样一个完整的智能指针大体已经实现完成,运行一下看看:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/// Demo type whose constructor and destructor announce themselves on stdout,
/// so the output shows exactly when the shared object is created and freed.
struct A
{
    A()
    {
        std::cout << "A() \n";
    }

    ~A()
    {
        std::cout << "~A() \n";
    }
};
int main()
{
A* a = new A;
SharedPtr<A> ptr(a);
{
std::cout << ptr.use_count() << std::endl;
SharedPtr<A> b = ptr;
std::cout << ptr.use_count() << std::endl;
SharedPtr<A> c = ptr;
std::cout << ptr.use_count() << std::endl;
SharedPtr<A> d = std::move(b);
std::cout << ptr.use_count() << std::endl;
}
std::cout << ptr.use_count() << std::endl;
return 0;
}

       结果为:

1
2
3
4
5
6
7
A()
1
2
3
3
1
~A()

       基本的shared_ptr完成后,再来写点有意思的,不知道大家有没有用过这几个指针转换函数:

1
2
3
4
5
6
7
8
9
10
11
// Standard pointer-cast helpers, as declared in <memory> (namespace std;
// reinterpret_pointer_cast was added in C++17). Each returns a shared_ptr
// that shares ownership with r but stores the cast pointer.
template<class T, class U>
std::shared_ptr<T> static_pointer_cast(const std::shared_ptr<U>& r) noexcept;
template<class T, class U>
std::shared_ptr<T> const_pointer_cast(const std::shared_ptr<U>& r) noexcept;
// Returns an empty shared_ptr when the dynamic_cast fails.
template<class T, class U>
std::shared_ptr<T> dynamic_pointer_cast(const std::shared_ptr<U>& r) noexcept;
template<class T, class U>
std::shared_ptr<T> reinterpret_pointer_cast(const std::shared_ptr<U>& r) noexcept;

       我默认大家已经知道这几个函数的作用,这里直接研究一下它们的实现。核心思路是给SharedPtr增加一个构造函数:新指针与原指针共享同一个引用计数,但保存的是转换后的裸指针:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
...
template <typename U>
SharedPtr(const SharedPtr<U>& p, T* ptr)
{
this->ptr_ = ptr;
this->ref_count_ = p.ref_count_;
ref_count_->add();
}
...
/// Shares ownership with `p` while viewing its object as a T via static_cast.
template <typename T, typename U>
SharedPtr<T> static_pointer_cast(const SharedPtr<U>& p) noexcept
{
    return SharedPtr<T>(p, static_cast<T*>(p.get()));
}

       SharedPtr<T>和SharedPtr<U>不是同一个类,所以上面的代码会有点问题:p无法访问它的private成员变量ref_count_。那怎么解决呢?上友元:

1
2
3
4
5
6
...
// Make every SharedPtr<U> a friend of SharedPtr<T>: different instantiations
// are unrelated classes, and the aliasing constructor needs to read the
// private ref_count_ of a SharedPtr<U>.
template <typename U>
friend class SharedPtr;
...

       再测试一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/// Demo base class that logs construction and destruction.
/// The destructor is virtual: deleting a B through an A* (which a smart
/// pointer holding an A* may do) is undefined behaviour otherwise. A
/// polymorphic A also makes dynamic_pointer_cast<B> usable on a SharedPtr<A>.
struct A
{
    A() { std::cout << "A() \n"; }
    virtual ~A() { std::cout << "~A() \n"; }
};
/// Demo derived class; its messages appear nested inside A's.
struct B : public A
{
    B() { std::cout << "B() \n"; }
    ~B() override { std::cout << "~B() \n"; }
};
void test_cast_shared()
{
B* a = new B;
SharedPtr<B> ptr(a);
{
std::cout << ptr.use_count() << std::endl;
SharedPtr<A> b = static_pointer_cast<A>(ptr);
std::cout << ptr.use_count() << std::endl;
SharedPtr<B> c = ptr;
std::cout << ptr.use_count() << std::endl;
SharedPtr<B> d = ptr;
std::cout << ptr.use_count() << std::endl;
}
std::cout << ptr.use_count() << std::endl;
}
/// Entry point: runs the pointer-cast demo.
int main()
{
    test_cast_shared();
    return 0;
}

       结果为:

1
2
3
4
5
6
7
8
9
A()
B()
1
2
3
4
1
~B()
~A()

       上面只实现了static_pointer_cast,其他xxx_pointer_cast的原理类似,大家应该也明白了吧。
       到这里已经实现了一个简单的shared_ptr以及配套的指针转换函数,希望对大家有所帮助,完整代码见这里:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#include <atomic>
#include <iostream>
/// Thread-safe reference counter shared by every SharedPtr that owns the same
/// object. It starts at 1 because it is created together with its first
/// owner. The virtual dispose() hook lets a derived counter destroy the
/// managed object through the type it was adopted with.
class SharedCount
{
public:
    SharedCount() : count_(1) {}
    virtual ~SharedCount() = default;
    /// Registers one more owner.
    void add() { ++count_; }
    /// Unregisters one owner and returns the number of owners left. Using the
    /// result of the atomic decrement (instead of a later get()) guarantees
    /// that exactly one releasing owner observes zero.
    int minus() { return --count_; }
    /// Snapshot of the current owner count.
    int get() const { return count_; }
    /// Destroys the managed object; the plain counter manages nothing.
    virtual void dispose() noexcept {}
private:
    std::atomic<int> count_;
};

/// Counter that also owns the adopted object and deletes it as a U*, i.e.
/// through its original type. Which SharedPtr<T> releases last (for example
/// one produced by static_pointer_cast or reinterpret_pointer_cast, or an
/// aliasing pointer to a sub-object) no longer matters.
template <typename U>
class OwningSharedCount : public SharedCount
{
public:
    explicit OwningSharedCount(U* owned) : owned_(owned) {}
    void dispose() noexcept override { delete owned_; }
private:
    U* owned_;
};

/// Reference-counted owning pointer modelled on std::shared_ptr.
/// A moved-from SharedPtr holds neither a pointer nor a counter.
template <typename T>
class SharedPtr
{
public:
    /// Empty pointer; it still gets its own counter, so use_count() == 1.
    SharedPtr() : ptr_(nullptr), ref_count_(new SharedCount) {}

    /// Adopts `ptr` (may be null). If allocating the counter throws, `ptr` is
    /// deleted before the exception propagates, so nothing leaks.
    SharedPtr(T* ptr) : ptr_(ptr), ref_count_(make_count(ptr)) {}

    ~SharedPtr() { clean(); }

    // Different instantiations are unrelated classes; the aliasing
    // constructor needs to read SharedPtr<U>::ref_count_.
    template <typename U>
    friend class SharedPtr;

    /// Aliasing constructor: shares p's counter (and so its lifetime) but
    /// stores `ptr`, typically p.get() converted by a *_pointer_cast helper.
    /// Destruction always goes through the counter, never through `ptr`.
    template <typename U>
    SharedPtr(const SharedPtr<U>& p, T* ptr) : ptr_(ptr), ref_count_(p.ref_count_)
    {
        if (ref_count_) ref_count_->add();  // p may be moved-from (no counter)
    }

    /// Copy constructor: shares p's object and registers one more owner.
    SharedPtr(const SharedPtr& p) : ptr_(p.ptr_), ref_count_(p.ref_count_)
    {
        if (ref_count_) ref_count_->add();  // p may be moved-from (no counter)
    }

    /// Copy assignment: releases the current object and shares p's instead.
    SharedPtr& operator=(const SharedPtr& p)
    {
        // Read p and take the new reference *before* releasing the old one.
        // This keeps self-assignment (and assignment from a SharedPtr stored
        // inside the object being released) free of use-after-free.
        T* new_ptr = p.ptr_;
        SharedCount* new_count = p.ref_count_;
        if (new_count) new_count->add();
        clean();
        ptr_ = new_ptr;
        ref_count_ = new_count;
        return *this;
    }

    /// Move constructor: takes over p's object; the count is unchanged and
    /// p is left empty.
    SharedPtr(SharedPtr&& p) noexcept : ptr_(p.ptr_), ref_count_(p.ref_count_)
    {
        p.ptr_ = nullptr;
        p.ref_count_ = nullptr;
    }

    /// Move assignment: releases the current object and takes over p's.
    SharedPtr& operator=(SharedPtr&& p) noexcept
    {
        // Detach p before releasing. On self-move the detach empties *this,
        // clean() is a no-op and the state is restored, so nothing is freed.
        T* new_ptr = p.ptr_;
        SharedCount* new_count = p.ref_count_;
        p.ptr_ = nullptr;
        p.ref_count_ = nullptr;
        clean();
        ptr_ = new_ptr;
        ref_count_ = new_count;
        return *this;
    }

    /// Number of owners sharing the object, or 0 for a moved-from pointer.
    int use_count() const { return ref_count_ ? ref_count_->get() : 0; }
    T* get() const { return ptr_; }
    T* operator->() const { return ptr_; }
    T& operator*() const { return *ptr_; }
    operator bool() const { return ptr_ != nullptr; }

private:
    /// Drops this owner's reference. The owner whose atomic decrement reaches
    /// zero destroys the managed object (via the counter) and the counter.
    /// Leaves *this empty.
    void clean() noexcept
    {
        if (ref_count_ && ref_count_->minus() == 0)
        {
            ref_count_->dispose();
            delete ref_count_;
        }
        ptr_ = nullptr;
        ref_count_ = nullptr;
    }

    /// Allocates the owning counter for a freshly adopted pointer and deletes
    /// the pointer if that allocation throws.
    static SharedCount* make_count(T* ptr)
    {
        try
        {
            return new OwningSharedCount<T>(ptr);
        }
        catch (...)
        {
            delete ptr;
            throw;
        }
    }

    T* ptr_;
    SharedCount* ref_count_;
};
/// Shares ownership with `p` while viewing its object as a T via static_cast.
template <typename T, typename U>
SharedPtr<T> static_pointer_cast(const SharedPtr<U>& p) noexcept
{
    return SharedPtr<T>(p, static_cast<T*>(p.get()));
}
/// Shares ownership with `p` while adding or removing cv-qualifiers via
/// const_cast.
template <typename T, typename U>
SharedPtr<T> const_pointer_cast(const SharedPtr<U>& p) noexcept
{
    return SharedPtr<T>(p, const_cast<T*>(p.get()));
}
/// Checked downcast: shares ownership with `p` when dynamic_cast succeeds,
/// otherwise returns a default-constructed (null) SharedPtr<T>.
template <typename T, typename U>
SharedPtr<T> dynamic_pointer_cast(const SharedPtr<U>& p) noexcept
{
    if (T* converted = dynamic_cast<T*>(p.get()))
    {
        return SharedPtr<T>(p, converted);
    }
    return SharedPtr<T>();
}
/// Shares ownership with `p` while reinterpreting its pointer as a T*.
template <typename T, typename U>
SharedPtr<T> reinterpret_pointer_cast(const SharedPtr<U>& p) noexcept
{
    return SharedPtr<T>(p, reinterpret_cast<T*>(p.get()));
}
/// Demo base class that logs construction and destruction.
/// The destructor is virtual: deleting a B through an A* (which a smart
/// pointer holding an A* may do) is undefined behaviour otherwise. A
/// polymorphic A also makes dynamic_pointer_cast<B> usable on a SharedPtr<A>.
struct A
{
    A() { std::cout << "A() \n"; }
    virtual ~A() { std::cout << "~A() \n"; }
};
/// Demo derived class; its messages appear nested inside A's.
struct B : public A
{
    B() { std::cout << "B() \n"; }
    ~B() override { std::cout << "~B() \n"; }
};
void test_main()
{
A* a = new A;
SharedPtr<A> ptr(a);
{
std::cout << ptr.use_count() << std::endl;
SharedPtr<A> b = ptr;
std::cout << ptr.use_count() << std::endl;
SharedPtr<A> c = ptr;
std::cout << ptr.use_count() << std::endl;
SharedPtr<A> d = std::move(b);
std::cout << ptr.use_count() << std::endl;
}
std::cout << ptr.use_count() << std::endl;
}
void test_static_cast_shared()
{
B* a = new B;
SharedPtr<B> ptr(a);
{
std::cout << ptr.use_count() << std::endl;
SharedPtr<A> b = static_pointer_cast<A>(ptr);
std::cout << ptr.use_count() << std::endl;
SharedPtr<B> c = ptr;
std::cout << ptr.use_count() << std::endl;
SharedPtr<B> d = ptr;
std::cout << ptr.use_count() << std::endl;
}
std::cout << ptr.use_count() << std::endl;
}
/// Entry point: runs the basic demo, a separator line, then the cast demo.
int main()
{
    test_main();
    std::cout << "==================" << std::endl;
    test_static_cast_shared();
}

文章目录