C++ 智能指针实现之shared_ptr

发布时间:2024年01月17日

前言

智能指针本质上并不神秘,其实就是 RAII 资源管理功能的自然展现而已。本文将介绍如何实现 C++中智能指针的 shared_ptr。

原理简介

多个不同的 shared_ptr 不仅可以共享一个对象,在共享同一对象时也需要同时共享同一个计数。当最后一个指向对象(和共享计数)的 shared_ptr 析构时,它需要删除对象和共享计数。

实现过程

我们先实现共享计数的接口,这个 shared_count 类除构造函数之外有三个方法:一个增加计数,一个减少计数,一个获取计数。

class shared_count {
public:
  shared_count() : count_(1) {}
  void add_count()
  {
    ++count_;
  }
  long reduce_count()
  {
    return --count_;
  }
  long get_count() const
  {
    return count_;
  }

private:
  long count_;
};

接下来可以实现我们的引用计数智能指针了。首先是构造函数、析构函数和私有成员变量:

template <typename T>
class shared_ptr {
public:
  explicit shared_ptr(T* ptr = nullptr)
    : ptr_(ptr)
  {
    if (ptr) {
      shared_count_ = new shared_count();
    }
  }
  ~shared_ptr()
  {
    if (ptr_ && !shared_count_->reduce_count()) {
      delete ptr_;
      delete shared_count_;
    }
  }

private:
  T* ptr_;
  shared_count* shared_count_;
};

构造函数会构造一个 shared_count 出来。析构函数在看到 ptr_ 非空时(此时根据代码逻辑, shared_count 也必然非空),需要对引用数减一,并在引用数降到零时彻底删除对象和共享计数。

为了方便实现赋值(及其他一些惯用法),我们需要一个新的 swap 成员函数:

void swap(shared_ptr& rhs)
{
  using std::swap;
  swap(ptr_, rhs.ptr_);
  swap(shared_count_, rhs.shared_count_);
}

赋值函数,拷贝构造和移动构造函数的实现:

shared_ptr(const shared_ptr& other)
{
  ptr_ = other.ptr_;
  if (ptr_) {
    other.shared_count_->add_count();
    shared_count_ = other.shared_count_;
  }
}
template <typename U>
shared_ptr(const shared_ptr<U>& other) noexcept
{
  ptr_ = other.ptr_;
  if (ptr_) {
    other.shared_count_->add_count();
    shared_count_ = other.shared_count_;
  }
}
template <typename U>
shared_ptr(shared_ptr<U>&& other) noexcept
{
  ptr_ = other.ptr_;
  if (ptr_) {
    shared_count_ = other.shared_count_;
    other.ptr_ = nullptr;
  }
}

除复制指针之外,对于拷贝构造的情况,我们需要在指针非空时把引用数加一,并复制共享计数的指针。对于移动构造的情况,我们不需要调整引用数,直接把 other.ptr_ 置为空,认为 other 不再指向该共享对象即可。

不过,上面的代码有个问题:它不能正确编译。编译器会报错,像:

fatal error: ‘ptr_’ is a private member of ‘shared_ptr

错误原因是模板的各个实例间并不天然就有 friend 关系,因而不能互访私有成员 ptr_shared_count_。我们需要在 shared_ptr 的定义中显式声明:

template <typename U>
friend class shared_ptr;

返回引用计数值

接下来,创建一个返回引用计数值的函数

long use_count() const
{
  if (ptr_) {
    return shared_count_
      ->get_count();
  } else {
    return 0;
  }
}

指针类型转换

对应于 C++ 里的不同的类型强制转换:

  • static_cast
  • reinterpret_cast
  • const_cast
  • dynamic_cast

智能指针需要实现类似的函数模板。实现本身并不复杂,但为了实现这些转换,我们需要添加构造函数,允许在对智能指针内部的指针对象赋值时,使用一个现有的智能指针的共享计数。如下所示:

template <typename U>
shared_ptr(const shared_ptr<U>& other, T* ptr)
{
    ptr_ = ptr;
    if (ptr_) {
        other.shared_count_->add_count();
        shared_count_ = other.shared_count_;
    }
}

这样我们就可以实现转换所需的函数模板了。下面实现一个 dynamic_pointer_cast 来示例一下:

template <typename T, typename U>
shared_ptr<T> dynamic_pointer_cast(const shared_ptr<U>& other)
{
  T* ptr = dynamic_cast<T*>(other.get());
  return shared_ptr<T>(other, ptr);
}

验证

我们可以用下面的代码来验证一下它的功能正常:

#include <iostream>
class shape {
public:
  virtual ~shape() {}
};

class circle : public shape {
public:
  ~circle() { std::cout<<"~circle()\n"; }
};

int main()
{
  shared_ptr<circle> ptr1(new circle());
  std::cout << "use count of ptr1 is" << ptr1.use_count() << "\n";
  shared_ptr<shape> ptr2;
  std::cout << "use count of ptr2 was " << ptr2.use_count() << "\n";
  ptr2 = ptr1;
  std::cout << "use count of ptr2 is now " << ptr2.use_count() << "\n";
  if (ptr1) {
    std::cout<<"ptr1 is not empty\n";
  }

  shared_ptr<circle> ptr3 = dynamic_pointer_cast<circle>(ptr2);
  std::cout << "use count of ptr3 is " << ptr3.use_count() << "\n";
}

输出:

use count of ptr1 is1
use count of ptr2 was 0
use count of ptr2 is now 2
ptr1 is not empty
use count of ptr3 is 3
~circle()

完整代码

#include <utility>  // std::swap

class shared_count {
public:
  shared_count() noexcept : count_(1) {}
  void add_count() noexcept
  {
    ++count_;
  }
  long reduce_count() noexcept
  {
    return --count_;
  }
  long get_count() const noexcept
  {
    return count_;
  }

private:
  long count_;
};

template <typename T>
class shared_ptr {
public:
  template <typename U>
  friend class shared_ptr;

  explicit shared_ptr(T* ptr = nullptr) : ptr_(ptr)
  {
    if (ptr) {
      shared_count_ = new shared_count();
    }
  }
  ~shared_ptr()
  {
    if (ptr_ && !shared_count_->reduce_count()) {
      delete ptr_;
      delete shared_count_;
    }
  }

  shared_ptr(const shared_ptr& other)
  {
    ptr_ = other.ptr_;
    if (ptr_) {
      other.shared_count_->add_count();
      shared_count_ = other.shared_count_;
    }
  }
  template <typename U>
  shared_ptr(const shared_ptr<U>& other) noexcept
  {
    ptr_ = other.ptr_;
    if (ptr_) {
      other.shared_count_->add_count();
      shared_count_ = other.shared_count_;
    }
  }
  template <typename U>
  shared_ptr(shared_ptr<U>&& other) noexcept
  {
    ptr_ = other.ptr_;
    if (ptr_) {
      shared_count_ = other.shared_count_;
      other.ptr_ = nullptr;
    }
  }
  template <typename U>
  shared_ptr(const shared_ptr<U>& other, T* ptr) noexcept
  {
    ptr_ = ptr;
    if (ptr_) {
      other.shared_count_->add_count();
      shared_count_ = other.shared_count_;
    }
  }
  shared_ptr& operator=(shared_ptr rhs) noexcept
  {
    rhs.swap(*this);
    return *this;
  }

  T* get() const noexcept
  {
    return ptr_;
  }
  long use_count() const noexcept
  {
    if (ptr_) {
      return shared_count_->get_count();
    }
    else {
      return 0;
    }
  }
  void swap(shared_ptr& rhs) noexcept
  {
    using std::swap;
    swap(ptr_, rhs.ptr_);
    swap(shared_count_, rhs.shared_count_);
  }

  T& operator*() const noexcept
  {
    return *ptr_;
  }
  T* operator->() const noexcept
  {
    return ptr_;
  }
  operator bool() const noexcept
  {
    return ptr_;
  }

private:
  T* ptr_;
  shared_count* shared_count_;
};

template <typename T>
void swap(shared_ptr<T>& lhs, shared_ptr<T>& rhs) noexcept
{
  lhs.swap(rhs);
}

template <typename T, typename U>
shared_ptr<T> static_pointer_cast(const shared_ptr<U>& other) noexcept
{
  T* ptr = static_cast<T*>(other.get());
  return shared_ptr<T>(other, ptr);
}

template <typename T, typename U>
shared_ptr<T> reinterpret_pointer_cast(const shared_ptr<U>& other) noexcept
{
  T* ptr = reinterpret_cast<T*>(other.get());
  return shared_ptr<T>(other, ptr);
}

template <typename T, typename U>
shared_ptr<T> const_pointer_cast(const shared_ptr<U>& other) noexcept
{
  T* ptr = const_cast<T*>(other.get());
  return shared_ptr<T>(other, ptr);
}

template <typename T, typename U>
shared_ptr<T> dynamic_pointer_cast(const shared_ptr<U>& other) noexcept
{
  T* ptr = dynamic_cast<T*>(other.get());
  return shared_ptr<T>(other, ptr);
}

在代码里加了不少 noexcept。这对这个智能指针在它的目标场景能正确使用是十分必要的。

总结

我们实现了一个基本完整的带引用计数的shared_ptr智能指针。从而对智能指针有一个比较深入的理解。当然,这里与标准的std::shared_ptr还欠缺一些东西,比如多线程安全、不支持自定义删除器以及和std::weak_ptr的配合。

《现代 C++编程实战》

文章来源:https://blog.csdn.net/no_say_you_know/article/details/135631429
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。