【python3】从库存管理分析分布式锁

发布时间:2024年01月09日

分布式锁是一种用于协调多个进程或线程之间访问共享资源的机制,它可以避免多个进程或线程同时对共享资源进行修改而导致的数据不一致问题。在分布式系统中,由于数据的分散存储在不同的节点上,因此需要一种可靠的分布式锁机制。

分布式锁通常需要满足以下几个条件:

  1. 互斥性:在任何时刻,只能有一个进程或线程获得锁。
  2. 安全性:一旦一个进程或线程获得锁,其他进程或线程无法修改该锁的状态,只有锁的持有者可以释放锁。
  3. 高可用性:分布式锁应该具有高可用性,即当某个节点或进程故障时,其他节点或进程可以接管该锁。
  4. 性能:分布式锁应该具有高性能,即在高并发的情况下,锁的获取和释放应该尽量快速
    当多个协程/线程/进程同时读写一个共享资源时,如果没有锁的情况下,会造成数据损坏。

例如一个秒杀活动中,商品123的库存都是100件,同时有2人参与秒杀(贱笑了*-*);假设有2个进程/线程/协程同一时刻对秒杀库存进行读写,各自将库存数目按照订单减库存,那么库存的中商品的数目最终会是多少呢?

并发下扣减库存

库存表

以下所有测试用例,均将数据还原至以上初始数据

实验环境

mysql 
redis
python3及peewee库

不加锁的情况

#! -*-conding=: UTF-8 -*-
# 2023/8/10 19:04
import random
import time
from datetime import datetime
import threading

from peewee import *
from playhouse.shortcuts import ReconnectMixin
from playhouse.pool import PooledMySQLDatabase


class ReconnectMySQLDatabase(ReconnectMixin, PooledMySQLDatabase):
    pass


db = ReconnectMySQLDatabase("inventory", host="192.168.91.1", port=3306, user="root", password="root")


# 删除 - 物理删除和逻辑删除 - 物理删除  -假设你把某个用户数据 - 用户购买记录,用户的收藏记录,用户浏览记录啊
# 通过save方法做了修改如何确保只修改update_time值而不是修改add_time
class BaseModel(Model):
    add_time = DateTimeField(default=datetime.now, verbose_name="添加时间")
    is_deleted = BooleanField(default=False, verbose_name="是否删除")
    update_time = DateTimeField(verbose_name="更新时间", default=datetime.now)

    def save(self, *args, **kwargs):
        # 判断这是一个新添加的数据还是更新的数据
        if self._pk is not None:
            # 这是一个新数据
            self.update_time = datetime.now()
        return super().save(*args, **kwargs)

    @classmethod
    def delete(cls, permanently=False):  # permanently表示是否永久删除
        if permanently:
            return super().delete()
        else:
            return super().update(is_deleted=True)

    def delete_instance(self, permanently=False, recursive=False, delete_nullable=False):
        if permanently:
            return self.delete(permanently).where(self._pk_expr()).execute()
        else:
            self.is_deleted = True
            self.save()

    @classmethod
    def select(cls, *fields):
        return super().select(*fields).where(cls.is_deleted == False)

    class Meta:
        database = db


class Inventory(BaseModel):
    # 商品的库存表
    # stock = PrimaryKeyField(Stock)
    goods = IntegerField(verbose_name="商品id", unique=True)
    stocks = IntegerField(verbose_name="库存数量", default=0)
    version = IntegerField(verbose_name="版本号", default=0)  # 用于分布式锁的乐观锁


def sell():
    # 多线程下的并发带来的数据不一致的问题
    goods_list = [(1, 10), (2, 20), (3, 30)]
    with db.atomic() as txn:
        # 超卖
        for goods_id, num in goods_list:
            # 查询库存
            goods_inv = Inventory.get(Inventory.goods == goods_id)
            print(f"商品{goods_id} 售出 {num}件")
            time.sleep(random.randint(1, 3))  # 增加并发问题的拟态实现

            if goods_inv.stocks < num:
                print(f"商品:{goods_id} 库存不足")
                txn.rollback()
                break
            else:
                goods_inv.stocks -= num
                goods_inv.save()


def create_data():
    db.create_tables([Inventory])
    for i in range(5):
        goods_inv = Inventory(goods=i, stocks=100)
        goods_inv.save()


if __name__ == "__main__":
    # create_data()

    t1 = threading.Thread(target=sell)
    t2 = threading.Thread(target=sell)
    t1.start()
    t2.start()

    t1.join()
    t2.join()

输出结果为:

商品1 售出 10件
商品1 售出 10件
商品2 售出 20件
商品3 售出 30件
商品2 售出 20件
商品3 售出 30

  1. 数据库连接和配置: 代码中使用了 PooledMySQLDatabase 对象连接MySQL数据库,具体配置为连接到IP地址为 192.168.91.1 的MySQL服务器,使用用户名 root 和密码 root 连接到数据库 inventory
  2. 基础模型 BaseModel: 定义了一个基础模型 BaseModel,其中包含了添加时间、更新时间、是否删除等字段的定义。此模型有一些方法,如保存(save)、删除(delete)和查询(select)等。
  3. 库存模型 Inventory: 基于 BaseModel 定义了一个库存模型 Inventory,其中包含商品id、库存数量、版本号等字段的定义。
  4. 库存售卖函数 sell: 这个函数用于模拟售卖商品的操作,使用多线程处理多个商品的售卖。在售卖过程中,先查询库存,然后根据库存数量进行扣减,但存在超卖问题,因为并发情况下会导致库存不足。
  5. 创建数据函数 create_data: 这个函数用于创建初始的商品库存数据,将5种商品的库存数量都设置为100。
  6. 多线程处理售卖操作: 代码主要在 __name__ == "__main__" 的分支中运行。首先通过 create_data() 函数创建初始库存数据,然后使用两个线程并发运行 sell() 函数模拟售卖商品。但由于多线程并发问题,可能会导致库存不足和超卖等问题。

问题:应该在更新的时候根据当前的数据更新。

不加锁(根据实时数据扣减库存)

修改代码如下(修改了更新数据的逻辑):

#! -*-conding=: UTF-8 -*-
# 2023/8/10 19:04
import random
import time
from datetime import datetime
import threading

from peewee import *
from playhouse.shortcuts import ReconnectMixin
from playhouse.pool import PooledMySQLDatabase


class ReconnectMySQLDatabase(ReconnectMixin, PooledMySQLDatabase):
    pass


db = ReconnectMySQLDatabase("inventory", host="192.168.91.1", port=3306, user="root", password="root")


# 删除 - 物理删除和逻辑删除 - 物理删除  -假设你把某个用户数据 - 用户购买记录,用户的收藏记录,用户浏览记录啊
# 通过save方法做了修改如何确保只修改update_time值而不是修改add_time
class BaseModel(Model):
    add_time = DateTimeField(default=datetime.now, verbose_name="添加时间")
    is_deleted = BooleanField(default=False, verbose_name="是否删除")
    update_time = DateTimeField(verbose_name="更新时间", default=datetime.now)

    def save(self, *args, **kwargs):
        # 判断这是一个新添加的数据还是更新的数据
        if self._pk is not None:
            # 这是一个新数据
            self.update_time = datetime.now()
        return super().save(*args, **kwargs)

    @classmethod
    def delete(cls, permanently=False):  # permanently表示是否永久删除
        if permanently:
            return super().delete()
        else:
            return super().update(is_deleted=True)

    def delete_instance(self, permanently=False, recursive=False, delete_nullable=False):
        if permanently:
            return self.delete(permanently).where(self._pk_expr()).execute()
        else:
            self.is_deleted = True
            self.save()

    @classmethod
    def select(cls, *fields):
        return super().select(*fields).where(cls.is_deleted == False)

    class Meta:
        database = db


class Inventory(BaseModel):
    # 商品的库存表
    # stock = PrimaryKeyField(Stock)
    goods = IntegerField(verbose_name="商品id", unique=True)
    stocks = IntegerField(verbose_name="库存数量", default=0)
    version = IntegerField(verbose_name="版本号", default=0)  # 用于分布式锁的乐观锁


def sell():
    # 多线程下的并发带来的数据不一致的问题
    goods_list = [(1, 99), (2, 20), (3, 30)]
    with db.atomic() as txn:
        # 超卖
        for goods_id, num in goods_list:
            # 查询库存
            goods_inv = Inventory.get(Inventory.goods == goods_id)
            
            time.sleep(random.randint(1, 3))

            if goods_inv.stocks < num:
                print(f"商品:{goods_id} 库存不足")
                txn.rollback()
                break
            else:
                # 让数据库根据自己当前的值更新数据
                query = Inventory.update(stocks=Inventory.stocks - num).where(Inventory.goods == goods_id)
                ok = query.execute()
                print(f"商品{goods_id} 售出 {num}件")
                if ok:
                    print("更新成功")
                else:
                    print("更新失败")


def create_data():
    db.create_tables([Inventory])
    for i in range(5):
        goods_inv = Inventory(goods=i, stocks=100)
        goods_inv.save()


if __name__ == "__main__":
    # create_data()

    t1 = threading.Thread(target=sell)
    t2 = threading.Thread(target=sell)
    t1.start()
    t2.start()

    t1.join()
    t2.join()

输出结果为:

商品1 售出 99件
商品1 售出 99件
更新成功
商品2 售出 20件
更新成功
商品3 售出 30件
更新成功
更新成功
商品2 售出 20件
更新成功
商品3 售出 30件
更新成功

咦,商品售出99件后为啥还能售出第二次99件?还是出现超卖现象了!!读→更新这里不是原子的。

数据库里的数据页证明了超卖了:

这还是不能处理并发问题。

加锁

单实例锁

from datetime import datetime
import threading
import time
from random import randint

from peewee import *
from playhouse.shortcuts import ReconnectMixin
from playhouse.pool import PooledMySQLDatabase


class ReconnectMySQLDatabase(ReconnectMixin, PooledMySQLDatabase):
    pass


db = ReconnectMySQLDatabase("inventory", host="192.168.91.1", port=3306, user="root", password="root")


# 删除 - 物理删除和逻辑删除 - 物理删除  -假设你把某个用户数据 - 用户购买记录,用户的收藏记录,用户浏览记录啊
# 通过save方法做了修改如何确保只修改update_time值而不是修改add_time
class BaseModel(Model):
    add_time = DateTimeField(default=datetime.now, verbose_name="添加时间")
    is_deleted = BooleanField(default=False, verbose_name="是否删除")
    update_time = DateTimeField(verbose_name="更新时间", default=datetime.now)

    def save(self, *args, **kwargs):
        # 判断这是一个新添加的数据还是更新的数据
        if self._pk is not None:
            # 这是一个新数据
            self.update_time = datetime.now()
        return super().save(*args, **kwargs)

    @classmethod
    def delete(cls, permanently=False):  # permanently表示是否永久删除
        if permanently:
            return super().delete()
        else:
            return super().update(is_deleted=True)

    def delete_instance(self, permanently=False, recursive=False, delete_nullable=False):
        if permanently:
            return self.delete(permanently).where(self._pk_expr()).execute()
        else:
            self.is_deleted = True
            self.save()

    @classmethod
    def select(cls, *fields):
        return super().select(*fields).where(cls.is_deleted == False)

    class Meta:
        database = db


class Inventory(BaseModel):
    # 商品的库存表
    # stock = PrimaryKeyField(Stock)
    goods = IntegerField(verbose_name="商品id", unique=True)
    stocks = IntegerField(verbose_name="库存数量", default=0)
    version = IntegerField(verbose_name="版本号", default=0)  # 分布式锁的乐观锁


R = threading.Lock()


def sell():
    # 多线程下的并发带来的数据不一致的问题
    goods_list = [(1, 10), (2, 20), (3, 99)]
    with db.atomic() as txn:
        # 超卖
        for goods_id, num in goods_list:
            # 查询库存
            with R:
                goods_inv = Inventory.get(Inventory.goods == goods_id)

                time.sleep(randint(1, 3))
                if goods_inv.stocks < num:
                    print(f"商品:{goods_id} 库存不足")
                    txn.rollback()
                    break
                else:
                    # 让数据库根据自己当前的值更新数据, 这个语句能不能处理并发的问题
                    query = Inventory.update(stocks=Inventory.stocks - num).where(Inventory.goods == goods_id)
                    ok = query.execute()
                    print(f"商品{goods_id} 售出 {num}件")
                    if ok:
                        print("更新成功")
                    else:
                        print("更新失败")


def create_data():
    db.create_tables([Inventory])
    for i in range(5):
        goods_inv = Inventory(goods=i, stocks=100)
        goods_inv.save()


if __name__ == "__main__":
    # create_data()

    t1 = threading.Thread(target=sell)
    t2 = threading.Thread(target=sell)
    t1.start()
    t2.start()

    t1.join()
    t2.join()

输出结果为:

商品1 售出 10件
更新成功
商品2 售出 20件
更新成功
商品3 售出 99件
更新成功
商品1 售出 10件
更新成功
商品2 售出 20件
更新成功
商品:3 库存不足

单体服务中这样实现是可以的,但是在微服务中,普通的锁机制失效。

MySQL分布式锁

基于mysql的悲观锁实现

悲观锁适用于对并发要求不高但需要确保操作的一致性的场景

  • 悲观锁概念:顾名思义,就是对于数据的处理持悲观态度,总认为会发生并发冲突,获取和修改数据时,别人会修改数据;所以在整个数据处理过程中,需要将数据锁定

  • 悲观锁的实现:通常依靠数据库提供的锁机制实现,比如mysql的排他锁,select … for update来实现悲观锁;例如,商品秒杀过程中,库存数量的减少,避免出现超卖的情况

mysql中的悲观锁实现:for update
  • mysql请求一把锁for update

  • 使用for update的时候要注意:每个语句mysql都是默认提交的

  • 需要关闭autocommit:set autocommit=0;(注意这个只针对当前窗口有效,不是全局的);(查询select @@autocommit;

  • 具体执行逻辑:select * from inventary where goods=1 for update;

  • 释放锁:commit;

for update的本质
  • 其实是行锁,只会锁住满足条件的数据,where goods=1where goods=2这2个是不会触发锁的

  • 如果条件部分没有索引goods,那么行锁会升级成表锁

  • 锁只是锁住要更新的语句for update,普通的查询不会锁住

  • 如果没有满足条件,不会锁表

使用悲观锁来实现防止超卖的效果,可以使用数据库的行级锁来保证在读取库存时进行锁定,从而避免并发问题。以下是使用悲观锁来实现的示例代码:

#! -*-conding=: UTF-8 -*-
# 2023/8/11 15:07


from datetime import datetime
import threading
from peewee import *

# 定义数据库连接
db = MySQLDatabase("inventory", host="192.168.91.1", port=3306, user="root", password="root")


# 定义基础模型
class BaseModel(Model):
    class Meta:
        database = db


# 定义库存表
class Inventory(BaseModel):
    goods = IntegerField(verbose_name="商品id", unique=True)
    stocks = IntegerField(verbose_name="库存数量", default=0)
    version = IntegerField(verbose_name="版本号", default=0)


# 初始化数据库连接
db.connect()

# 创建库存表
db.create_tables([Inventory], safe=True)


# 出售商品函数,使用悲观锁
def sell(order_list):
    with db.atomic() as txn:
        for goods_id, num in order_list:
            try:
                # 查询库存并锁定记录
                goods_inv = Inventory.select().where(Inventory.goods == goods_id).for_update().get()
                if goods_inv.stocks < num:
                    print(f"商品:{goods_id} 库存不足")
                    txn.rollback()
                    return False
                # 更新库存和版本号
                goods_inv.stocks -= num
                goods_inv.version += 1
                goods_inv.save()
                print(f"商品{goods_id} 售出 {num}件")
            except Inventory.DoesNotExist:
                print(f"商品{goods_id} 不存在")
                txn.rollback()
                return False

        print(f"订单成功,商品列表:{order_list}")
        txn.commit()
        return True


# 创建库存初始数据
def create_data():
    with db.atomic():
        for i in range(5):
            Inventory.create(goods=i, stocks=100)


# 主程序
if __name__ == "__main__":
    # create_data()

    # 模拟一批订单并发出售商品
    order_list = [(1, 10), (2, 20), (3, 30)]

    current_stock = {}
    for goods_id, _ in order_list:
        stock = Inventory.get(Inventory.goods == goods_id).stocks
        current_stock[goods_id] = stock
    print(f"当前库存:{current_stock}")

    num_threads = 2  # 可根据需要调整线程数
    threads = []
    for _ in range(num_threads):
        t = threading.Thread(target=sell, args=(order_list,))
        threads.append(t)

    for t in threads:
        t.start()

    for t in threads:
        t.join()

    # 查询库存
    final_stock = {}
    for goods_id, _ in order_list:
        stock = Inventory.get(Inventory.goods == goods_id).stocks
        final_stock[goods_id] = stock
    print(f"最终库存:{final_stock}")

输出结果为:

当前库存:{1: 100, 2: 100, 3: 100}
商品1 售出 10件
商品2 售出 20件
商品3 售出 30件
订单成功,商品列表:[(1, 10), (2, 20), (3, 30)]
商品1 售出 10件
商品2 售出 20件
商品3 售出 30件
订单成功,商品列表:[(1, 10), (2, 20), (3, 30)]
最终库存:{1: 80, 2: 60, 3: 40}

在这个示例代码中,我们使用了for_update()方法来锁定库存记录,确保在查询库存时不会被其他线程修改。在处理订单时,我们检查库存并更新库存和版本号。使用悲观锁的方式可以保证在处理订单过程中库存的一致性,避免超卖。

基于mysql的乐观锁

  • 乐观锁概念:乐观锁准确的说不是一种锁,而是解决数据不一致的方案
  • 乐观锁的实现原理

乐观锁适用于需要更高并发性能但可能需要做更多的冲突检测重试

import threading
import time
from peewee import *

# 定义数据库连接
db = MySQLDatabase("inventory", host="192.168.91.1", port=3306, user="root", password="root")


# 定义基础模型
class BaseModel(Model):
    class Meta:
        database = db


# 定义库存表
class Inventory(BaseModel):
    goods = IntegerField(verbose_name="商品id", unique=True)
    stocks = IntegerField(verbose_name="库存数量", default=0)
    version = IntegerField(verbose_name="版本号", default=0)


# 初始化数据库连接
db.connect()

# 创建库存表
db.create_tables([Inventory], safe=True)


# 出售商品函数,使用乐观锁
def sell2(goods_list, i):
    # 演示基于数据库的乐观锁机制
    max_retry = 3
    retry = 0
    stock = True
    while retry < max_retry:
        all_updated = True
        with db.atomic() as txn:
            for goods_id, num in goods_list:
                try:
                    goods_inv = Inventory.select().where(Inventory.goods == goods_id).get()
                    print(f"线程: {i} : 商品{goods_id} 当前库存:{goods_inv.stocks}, 版本: {goods_inv.version}")

                    if goods_inv.stocks < num:
                        print(f"商品{goods_id} 库存不足")
                        all_updated = False
                        stock = False
                        break

                    time.sleep(1)  # 模拟多个线程都读到数据,但是未更新时的状态:竞态

                    query = Inventory.update(stocks=Inventory.stocks - num, version=Inventory.version + 1).where(
                        Inventory.goods == goods_id, Inventory.version == goods_inv.version)

                    rows_updated = query.execute()

                    if rows_updated != 1:
                        print(f"线程{i} : 商品{goods_id} 更新失败,可能被其他线程修改,正在重试...")
                        all_updated = False
                        break
                except Inventory.DoesNotExist:
                    print(f"商品{goods_id} 不存在")
                    all_updated = False
                    break

            if all_updated:
                print(f"线程{i} : 所有商品更新成功")
                txn.commit()
                break
            else:
                retry += 1
                txn.rollback()
                if not stock:
                    break


# 创建库存初始数据
def create_data():
    with db.atomic():
        for i in range(5):
            Inventory.create(goods=i, stocks=100)


# 主程序
if __name__ == "__main__":
    # create_data()

    # 模拟一批订单并发出售商品
    order_list = [(1, 50), (2, 20), (3, 30)]

    current_stock = {}
    for goods_id, _ in order_list:
        stock = Inventory.get(Inventory.goods == goods_id).stocks
        current_stock[goods_id] = stock
    print(f"当前库存:{current_stock}")

    num_threads = 2  # 可根据需要调整线程数
    threads = []
    for i in range(num_threads):
        t = threading.Thread(target=sell2, args=(order_list, i))
        threads.append(t)

    for t in threads:
        t.start()

    for t in threads:
        t.join()

    # 查询库存
    final_stock = {}
    for goods_id, _ in order_list:
        stock = Inventory.get(Inventory.goods == goods_id).stocks
        final_stock[goods_id] = stock
    print(f"最终库存:{final_stock}")

输出结果为:

当前库存:{1: 100, 2: 100, 3: 100}
线程: 0 : 商品1 当前库存:100, 版本: 14
线程: 1 : 商品1 当前库存:100, 版本: 14
线程: 0 : 商品2 当前库存:100, 版本: 14
线程: 0 : 商品3 当前库存:100, 版本: 14
线程0 : 所有商品更新成功
线程1 : 商品1 更新失败,可能被其他线程修改,正在重试...
线程: 1 : 商品1 当前库存:50, 版本: 15
线程: 1 : 商品2 当前库存:80, 版本: 15
线程: 1 : 商品3 当前库存:70, 版本: 15
线程1 : 所有商品更新成功
最终库存:{1: 0, 2: 60, 3: 40}

在这个示例代码中,sell函数来接受一个订单列表order_list,每个订单包含商品ID和数量。在处理订单时,我们先检查所有商品的库存是否充足,如果充足则同时更新所有商品的库存和版本号,如果库存不足则回滚。这样可以保证订单中的所有商品在处理过程中要么全部更新成功,要么全部回滚。这种方式可以保证同一个订单的商品购买数量的一致性,并且避免超卖。

需要注意的是,由于多线程并发执行,数据库连接和操作的线程安全性需要保证。此外,乐观锁的方式在并发高的情况下可能会导致较多的重试,因此需要合理设计并发情况下的策略,确保库存更新的正确性。

在实际应用中,还需要考虑异常处理、数据库连接池的管理、线程数的设置、数据库索引的优化等因素,以确保系统的性能和稳定性。

Redis分布式锁

Redis 分布式锁中需要用到的命令:

  • SET key value [EX seconds | PX milliseconds]:设置带过期时间的key-value
  • EXPIRE key seconds:给指定的key设置过期时间。
  • GET key:获取给定的keyvalue。
  • DEL key:删除给定的key。
  • SETNX key value:如果key不存在,则设置key-value,反正设置失败。这里用SETGET一起使用代替。

使用Redis实现分布式锁来防止超卖需要借助Redis的原子性操作。我们可以使用Redis的SETNX命令(或者Redis的分布式锁实现方式,如RedLock等)来实现一个分布式锁,确保在某个时刻只有一个线程能够进行库存的检查和更新。下面是一个使用Redis分布式锁来防止超卖的示例代码,前提是你需要在Python环境中安装redis-py库。

获取锁

Redis中,一个相同的key代表一把锁。是否拥有这把锁,需要判断keyvalue是否是自己设置的,同时还要判断锁是否已经过期。

以下是某个实例加锁的步骤:

  1. 通过GET命令获取key,如果获取不到key,说明还没有加锁;
  2. 如果没有加锁,则使用SET命令设置key,同时设置锁的过期时间,加锁成功。返回;
  3. 如果获取到了key,并且value是自己设置的,证明该实例已经加锁成功,此时需要使用EXPIRE命令为锁添加过期时间,因为这次可能是重试,前一次已经加锁成功。返回;
  4. 如果获取到了key,但是value不属于自己设置的,证明已经被其他实例抢到了锁,加锁失败。
  5. 加锁失败,则继续进行 1~4 步骤,直至超时或者加锁成功。

以上 1~4 步骤需要原子性操作,可以通过lua脚本进行封装。

释放锁

  1. 主动释放:释放锁其实就是删除key,使用DEL命令进行删除。删除key前,需要判断key对应的value是否为自己设置的value,如果不是,证明锁已经被其他实例获取。判断和删除都也需要是原子操作。
  2. 过期释放:由于锁(即key)设置了过期,如果锁没有被续期(增加过期时间),就会被 Redis 删除。

需要注意的是,使用 Redis 实现分布式锁需要考虑一些问题,例如Redis实例的可用性、网络延迟、锁的持有者异常退出等,需要进行合理的设计和实现。另外,为了保证锁的正确性和可靠性,可以采用一些常用的技术手段,例如设置合适的超时时间、使用RedLock 算法、采用Lua脚本等

1.0版本

# ! -*-conding=: UTF-8 -*-
# 2023/8/11 19:47

import redis
import time
import threading
from random import randint
from datetime import datetime
from peewee import *
from playhouse.shortcuts import ReconnectMixin
from playhouse.pool import PooledMySQLDatabase


class ReconnectMySQLDatabase(ReconnectMixin, PooledMySQLDatabase):
    pass


host = "192.168.91.1"

db = ReconnectMySQLDatabase("inventory", host=host, port=3306, user="root", password="root")


# 这个BaseModel可以忽略  只不过是重写了peewee中的cudr的一些操作 继承就行了
class BaseModel(Model):
    add_time = DateTimeField(default=datetime.now, verbose_name="添加时间")
    is_deleted = BooleanField(default=False, verbose_name="是否删除")
    update_time = DateTimeField(verbose_name="更新时间", default=datetime.now)

    def save(self, *args, **kwargs):
        # 判断这是一个新添加的数据还是更新的数据
        if self._pk is not None:
            # 这是一个新数据
            self.update_time = datetime.now()
        return super().save(*args, **kwargs)

    @classmethod
    def delete(cls, permanently=False):  # permanently表示是否永久删除
        if permanently:
            return super().delete()
        else:
            return super().update(is_deleted=True)

    def delete_instance(self, permanently=False, recursive=False, delete_nullable=False):
        if permanently:
            return self.delete(permanently).where(self._pk_expr()).execute()
        else:
            self.is_deleted = True
            self.save()

    @classmethod
    def select(cls, *fields):
        return super().select(*fields).where(cls.is_deleted == False)

    class Meta:
        database = db


class Inventory(BaseModel):
    # 商品的库存表
    # stock = PrimaryKeyField(Stock)
    goods = IntegerField(verbose_name="商品id", unique=True)
    stocks = IntegerField(verbose_name="库存数量", default=0)
    version = IntegerField(verbose_name="版本号", default=0)  # 分布式锁的乐观锁  这里没用到


# 写一个redis分布式锁
class Lock:
    # 初始化
    def __init__(self, name):
        self.redis_client = redis.Redis(host=host, port=6379)
        self.name = name

    # 上锁
    def acquire(self):
        if not self.redis_client.get(self.name):
            self.redis_client.set(self.name, 1)
            return True
        else:
            while True:
                import time
                time.sleep(1)
                if self.redis_client.get(self.name):
                    return True

    # 释放锁
    def release(self):
        self.redis_client.delete(self.name)


def sell2():
    # 多线程下的并发带来的数据不一致的问题
    # 顾客(goods_list)商品id为1的买10件以此类推
    goods_list = [(1, 10), (2, 20), (3, 30)]
    # 事务
    with db.atomic() as txn:
        # 超卖
        for goods_id, num in goods_list:
            # 获取锁
            lock = Lock(f"lock:goods_{goods_id}")
            # 上锁
            lock.acquire()
            # 查询库存
            goods_inv = Inventory.get(Inventory.goods == goods_id)
            print(f"商品{goods_id} 售出 {num}件")
            time.sleep(randint(1, 3))
            if goods_inv.stocks < num:
                print(f"商品:{goods_id} 库存不足")
                txn.rollback()
                lock.release()  # 释放锁
                break
            else:
                # mysql中有修改的情况下另一个修改将无法进行 这是mysql的原子性
                query = Inventory.update(stocks=Inventory.stocks - num).where(Inventory.goods == goods_id)
                ok = query.execute()
                if ok:
                    print("更新成功")
                else:
                    print("更新失败")
            lock.release()  # 释放锁


if __name__ == '__main__':
    # 开两个线程

    t1 = threading.Thread(target=sell2)
    t2 = threading.Thread(target=sell2)
    t1.start()
    t2.start()

    t1.join()
    t2.join()

输出结果为:

商品1 售出 10件
商品1 售出 10件
更新成功
商品2 售出 20件
更新成功
商品3 售出 30件
更新成功
更新成功
商品2 售出 20件
更新成功
商品3 售出 30件
更新成功

当并发非常高的时候还是会出现超卖的情况,问题出在以下代码中,这里的get和set不是原子性的

if self.redis_client.get(self.name):
    self.redis_client.set(self.name, 1)
    return True

2.0版本

setnx版

  • 使用setnx确保获取和设置key是原子性
#! -*-conding=: UTF-8 -*-
# 2023/8/11 19:47
import redis
import time
import threading
from random import randint
from datetime import datetime
from peewee import *
from playhouse.shortcuts import ReconnectMixin
from playhouse.pool import PooledMySQLDatabase


class ReconnectMySQLDatabase(ReconnectMixin, PooledMySQLDatabase):
    pass


host = "192.168.91.1"

db = ReconnectMySQLDatabase("inventory", host=host, port=3306, user="root", password="root")


# 这个BaseModel可以忽略  只不过是重写了peewee中的cudr的一些操作 继承就行了
class BaseModel(Model):
    add_time = DateTimeField(default=datetime.now, verbose_name="添加时间")
    is_deleted = BooleanField(default=False, verbose_name="是否删除")
    update_time = DateTimeField(verbose_name="更新时间", default=datetime.now)

    def save(self, *args, **kwargs):
        # 判断这是一个新添加的数据还是更新的数据
        if self._pk is not None:
            # 这是一个新数据
            self.update_time = datetime.now()
        return super().save(*args, **kwargs)

    @classmethod
    def delete(cls, permanently=False):  # permanently表示是否永久删除
        if permanently:
            return super().delete()
        else:
            return super().update(is_deleted=True)

    def delete_instance(self, permanently=False, recursive=False, delete_nullable=False):
        if permanently:
            return self.delete(permanently).where(self._pk_expr()).execute()
        else:
            self.is_deleted = True
            self.save()

    @classmethod
    def select(cls, *fields):
        return super().select(*fields).where(cls.is_deleted == False)

    class Meta:
        database = db


class Inventory(BaseModel):
    # 商品的库存表
    # stock = PrimaryKeyField(Stock)
    goods = IntegerField(verbose_name="商品id", unique=True)
    stocks = IntegerField(verbose_name="库存数量", default=0)
    version = IntegerField(verbose_name="版本号", default=0)  # 分布式锁的乐观锁  这里没用到


# 写一个redis分布式锁
class Lock:
    # 初始化
    def __init__(self, name):
        self.redis_client = redis.Redis(host=host, port=6379)
        self.name = name

    # 上锁
    def acquire(self):
        if self.redis_client.setnx(self.name, 1):  # 如果不存在设置并且返回1,否则返回0,这是原子操作
            return True
        else:
            while True:
                import time
                time.sleep(1)
                if self.redis_client.setnx(self.name, 1):
                    return True

    # 释放锁
    def release(self):
        self.redis_client.delete(self.name)


def sell2():
    # 多线程下的并发带来的数据不一致的问题
    # 顾客(goods_list)商品id为1的买10件以此类推
    goods_list = [(1, 10), (2, 20), (3, 30)]
    # 事务
    with db.atomic() as txn:
        # 超卖
        for goods_id, num in goods_list:
            # 获取锁
            lock = Lock(f"lock:goods_{goods_id}")
            # 上锁
            lock.acquire()
            # 查询库存
            goods_inv = Inventory.get(Inventory.goods == goods_id)
            print(f"商品{goods_id} 售出 {num}件")
            time.sleep(randint(1, 3))
            if goods_inv.stocks < num:
                print(f"商品:{goods_id} 库存不足")
                txn.rollback()
                lock.release()  # 释放锁
                break
            else:
                # 让数据库根据自己当前的值更新数据, 这个语句能不能处理并发的问题
                # mysql中有修改的情况下另一个修改将无法进行 这是mysql的原子性
                query = Inventory.update(stocks=Inventory.stocks - num).where(Inventory.goods == goods_id)
                ok = query.execute()
                if ok:
                    print("更新成功")
                else:
                    print("更新失败")
            lock.release()  # 释放锁


if __name__ == '__main__':
    # 开两个线程
    t1 = threading.Thread(target=sell2)
    t2 = threading.Thread(target=sell2)
    t1.start()
    t2.start()

    t1.join()
    t2.join()

输出结果为:

商品1 售出 10件
更新成功
商品2 售出 20件
商品1 售出 10件
更新成功
商品3 售出 30件
更新成功
更新成功
商品2 售出 20件
更新成功
商品3 售出 30件
更新成功

分布式锁需要解决的问题:

互斥性:任意时刻只能有一个客户端拥有锁,不能同时多个客户端获取

安全性锁只能被持有该锁的用户删除,而不能被其他用户删除

死锁:获取锁的客户端因为某些原因而宕机,而未能释放锁,其他客户端无法获取此锁,需要有机制来避免该类问题的发生

  1. 代码异常,导致无法运行到release
  2. 你的当前服务器网络出问题 - 脑裂
  3. 断电

容错:当部分节点宕机,客户端仍能获取锁或者释放锁

如何解决上述问题的发生 - 设置过期时间

过期设置会产生新的问题

  1. 当前的线程如果在一段时间后没有执行完, 当前的程序没有执行完,然后key过期了
  • 不安全

  • 另一个线程进来以后会将当前的key给删除掉,另一个线程删除掉了本该属于我设置的值

  • 如果当前的线程没有执行完,那我的这个线程还应该在适当的时候去续租,将过期时间重新设置

    • 应该在什么时候去设置过期 - 15s的2/3的时候去续租,也就是运行10s以后去将过期时间重新设置为15s
    • 如何去定时的完成这个续租的过程 - 启动一个线程去完成

3.0版本

set版本增加过期时间及值设置为随机字符串(只能删除自己设置的锁)

#! -*-conding=: UTF-8 -*-
# 2023/8/11 19:47
import redis
import time
import threading
from random import randint
from datetime import datetime
from peewee import *
from playhouse.shortcuts import ReconnectMixin
from playhouse.pool import PooledMySQLDatabase


class ReconnectMySQLDatabase(ReconnectMixin, PooledMySQLDatabase):
    pass


host = "192.168.91.1"

db = ReconnectMySQLDatabase("inventory", host=host, port=3306, user="root", password="root")


# 这个BaseModel可以忽略  只不过是重写了peewee中的cudr的一些操作 继承就行了
class BaseModel(Model):
    add_time = DateTimeField(default=datetime.now, verbose_name="添加时间")
    is_deleted = BooleanField(default=False, verbose_name="是否删除")
    update_time = DateTimeField(verbose_name="更新时间", default=datetime.now)

    def save(self, *args, **kwargs):
        # 判断这是一个新添加的数据还是更新的数据
        if self._pk is not None:
            # 这是一个新数据
            self.update_time = datetime.now()
        return super().save(*args, **kwargs)

    @classmethod
    def delete(cls, permanently=False):  # permanently表示是否永久删除
        if permanently:
            return super().delete()
        else:
            return super().update(is_deleted=True)

    def delete_instance(self, permanently=False, recursive=False, delete_nullable=False):
        if permanently:
            return self.delete(permanently).where(self._pk_expr()).execute()
        else:
            self.is_deleted = True
            self.save()

    @classmethod
    def select(cls, *fields):
        return super().select(*fields).where(cls.is_deleted == False)

    class Meta:
        database = db


class Inventory(BaseModel):
    # 商品的库存表
    # stock = PrimaryKeyField(Stock)
    goods = IntegerField(verbose_name="商品id", unique=True)
    stocks = IntegerField(verbose_name="库存数量", default=0)
    version = IntegerField(verbose_name="版本号", default=0)  # 分布式锁的乐观锁  这里没用到


# 写一个redis分布式锁
class Lock:
    # 初始化
    def __init__(self, name, lock_id):
        self.lock_id = lock_id
        self.redis_client = redis.Redis(host=host, port=6379)
        self.name = name

    # 上锁
    def acquire(self):
        if self.redis_client.set(self.name, self.lock_id,  nx=True, ex=15):
            # 启动一个线程去定时的刷新这个过期时间,这个操作最好也是使用lua脚本
            return True
        else:
            while True:
                import time
                time.sleep(1)
                if self.redis_client.set(self.name, self.lock_id, nx=True, ex=15):
                    return True

    # 释放锁
    def release(self):
        lock_id = self.redis_client.get(self.name)
        if lock_id == self.lock_id:
            self.redis_client.delete(self.name)
        else:
            print("不能删除不属于自己的锁")


def sell2():
    # 多线程下的并发带来的数据不一致的问题
    # 顾客(goods_list)商品id为1的买10件以此类推
    goods_list = [(1, 10), (2, 20), (3, 30)]
    # 事务
    import uuid
    lock_uuid = str(uuid.uuid4())
    with db.atomic() as txn:
        # 超卖
        for goods_id, num in goods_list:
            # 获取锁
            lock = Lock(f"lock:goods_{goods_id}", lock_uuid)
            # 上锁
            lock.acquire()
            # 查询库存
            goods_inv = Inventory.get(Inventory.goods == goods_id)
            print(f"商品{goods_id} 售出 {num}件")
            time.sleep(randint(1, 3))
            if goods_inv.stocks < num:
                print(f"商品:{goods_id} 库存不足")
                txn.rollback()
                lock.release()  # 释放锁
                break
            else:
                # 让数据库根据自己当前的值更新数据, 这个语句能不能处理并发的问题
                # mysql中有修改的情况下另一个修改将无法进行 这是mysql的原子性
                query = Inventory.update(stocks=Inventory.stocks - num).where(Inventory.goods == goods_id)
                ok = query.execute()
                if ok:
                    print("更新成功")
                else:
                    print("更新失败")
            lock.release()  # 释放锁


if __name__ == '__main__':
    # 开两个线程
    t1 = threading.Thread(target=sell2)
    t2 = threading.Thread(target=sell2)
    t1.start()
    t2.start()

    t1.join()
    t2.join()

输出结果为:

商品1 售出 10件
更新成功
不能删除不属于自己的锁
商品2 售出 20件
更新成功
不能删除不属于自己的锁
商品3 售出 30件
更新成功
不能删除不属于自己的锁
商品1 售出 10件
更新成功
不能删除不属于自己的锁
商品2 售出 20件
更新成功
不能删除不属于自己的锁
商品3 售出 30件
更新成功
不能删除不属于自己的锁

释放锁的代码仍然可能存在问题,不是原子操作,可以使用lua脚本继续封装

#! -*-conding=: UTF-8 -*-
# 2023/8/11 19:47
import redis
import time
import threading
from random import randint
from datetime import datetime
from peewee import *
from playhouse.shortcuts import ReconnectMixin
from playhouse.pool import PooledMySQLDatabase


class ReconnectMySQLDatabase(ReconnectMixin, PooledMySQLDatabase):
    pass


host = "192.168.91.1"

db = ReconnectMySQLDatabase("inventory", host=host, port=3306, user="root", password="root")


# 这个BaseModel可以忽略  只不过是重写了peewee中的cudr的一些操作 继承就行了
class BaseModel(Model):
    add_time = DateTimeField(default=datetime.now, verbose_name="添加时间")
    is_deleted = BooleanField(default=False, verbose_name="是否删除")
    update_time = DateTimeField(verbose_name="更新时间", default=datetime.now)

    def save(self, *args, **kwargs):
        # 判断这是一个新添加的数据还是更新的数据
        if self._pk is not None:
            # 这是一个新数据
            self.update_time = datetime.now()
        return super().save(*args, **kwargs)

    @classmethod
    def delete(cls, permanently=False):  # permanently表示是否永久删除
        if permanently:
            return super().delete()
        else:
            return super().update(is_deleted=True)

    def delete_instance(self, permanently=False, recursive=False, delete_nullable=False):
        if permanently:
            return self.delete(permanently).where(self._pk_expr()).execute()
        else:
            self.is_deleted = True
            self.save()

    @classmethod
    def select(cls, *fields):
        return super().select(*fields).where(cls.is_deleted == False)

    class Meta:
        database = db


class Inventory(BaseModel):
    # 商品的库存表
    # stock = PrimaryKeyField(Stock)
    goods = IntegerField(verbose_name="商品id", unique=True)
    stocks = IntegerField(verbose_name="库存数量", default=0)
    version = IntegerField(verbose_name="版本号", default=0)  # 分布式锁的乐观锁  这里没用到


# 写一个redis分布式锁
class Lock:
    # 初始化
    def __init__(self, name, lock_id):
        self.lock_id = lock_id
        self.redis_client = redis.Redis(host=host, port=6379)
        self.name = name

    # 上锁
    def acquire(self):
        if self.redis_client.set(self.name, self.lock_id,  nx=True, ex=15):
            # 启动一个线程去定时的刷新这个过期时间,这个操作最好也是使用lua脚本
            return True
        else:
            while True:
                import time
                time.sleep(1)
                if self.redis_client.set(self.name, self.lock_id, nx=True, ex=15):
                    return True

    # 释放锁
    def release(self):
        # lock_id = self.redis_client.get(self.name)
        # if lock_id == self.lock_id:
        #     self.redis_client.delete(self.name)
        # else:
        #     print("不能删除不属于自己的锁")

        unlock_script = """
            if redis.call("get",KEYS[1]) == ARGV[1] then
                return redis.call("del",KEYS[1])
            else
                return 0
            end
            """
        unlock = self.redis_client.register_script(unlock_script)
        result = unlock(keys=[self.name], args=[self.lock_id])
        if result:
            return True
        else:
            print("不能删除不属于自己的锁")
            return False


def sell2():
    # 多线程下的并发带来的数据不一致的问题
    # 顾客(goods_list)商品id为1的买10件以此类推
    goods_list = [(1, 10), (2, 20), (3, 30)]
    # 事务
    import uuid
    lock_uuid = str(uuid.uuid4())
    with db.atomic() as txn:
        # 超卖
        for goods_id, num in goods_list:
            # 获取锁
            lock = Lock(f"lock:goods_{goods_id}", lock_uuid)
            # 上锁
            lock.acquire()
            # 查询库存
            goods_inv = Inventory.get(Inventory.goods == goods_id)
            print(f"商品{goods_id} 售出 {num}件")
            time.sleep(randint(1, 3))
            if goods_inv.stocks < num:
                print(f"商品:{goods_id} 库存不足")
                txn.rollback()
                lock.release()  # 释放锁
                break
            else:
                # 让数据库根据自己当前的值更新数据, 这个语句能不能处理并发的问题
                # mysql中有修改的情况下另一个修改将无法进行 这是mysql的原子性
                query = Inventory.update(stocks=Inventory.stocks - num).where(Inventory.goods == goods_id)
                ok = query.execute()
                if ok:
                    print("更新成功")
                else:
                    print("更新失败")
            lock.release()  # 释放锁


if __name__ == '__main__':
    # 开两个线程
    t1 = threading.Thread(target=sell2)
    t2 = threading.Thread(target=sell2)
    t1.start()
    t2.start()

    t1.join()
    t2.join()

输出结果为:

商品1 售出 10件
更新成功
商品1 售出 10件
商品2 售出 20件
更新成功
商品3 售出 30件
更新成功
更新成功
商品2 售出 20件
更新成功
商品3 售出 30件
更新成功

4.0版本

Redis官方库也为我们实现了分布式锁,根据需要使用即可

import threading
import redis
import time
from peewee import *

# 定义数据库连接
db = MySQLDatabase("inventory", host="192.168.91.1", port=3306, user="root", password="root")


# 定义基础模型
class BaseModel(Model):
    class Meta:
        database = db


# 定义库存表
class Inventory(BaseModel):
    goods = IntegerField(verbose_name="商品id", unique=True)
    stocks = IntegerField(verbose_name="库存数量", default=0)
    version = IntegerField(verbose_name="版本号", default=0)


# 初始化数据库连接
db.connect()

# 创建库存表
db.create_tables([Inventory], safe=True)

# 创建Redis连接
redis_client = redis.StrictRedis(host='192.168.91.1', port=6379, db=0, decode_responses=True)


# 出售商品函数,使用Redis锁
def sell(order_list):
    with db.atomic() as txn:
        # 检查所有商品库存是否充足
        for goods_id, num in order_list:
            try:
                # 查询库存,同时尝试获取Redis锁
                with redis_client.lock(f'inventory_lock:{goods_id}', blocking_timeout=10):
                    goods_inv = Inventory.select().where(Inventory.goods == goods_id).get()
                    if goods_inv.stocks < num:
                        print(f"商品:{goods_id} 库存不足")
                        txn.rollback()
                        return False
            except Inventory.DoesNotExist:
                print(f"商品{goods_id} 不存在")
                txn.rollback()
                return False

        # 更新所有商品库存和版本号
        for goods_id, num in order_list:
            query = Inventory.update(stocks=Inventory.stocks - num, version=Inventory.version + 1).where(
                Inventory.goods == goods_id)
            rows_updated = query.execute()
            if rows_updated != 1:
                print(f"商品{goods_id} 更新失败")
                txn.rollback()
                return False

        print(f"订单成功,商品列表:{order_list}")
        txn.commit()
        return True


# 创建库存初始数据
def create_data():
    with db.atomic():
        for i in range(5):
            Inventory.create(goods=i, stocks=100)


# 主程序
if __name__ == "__main__":
    # create_data()

    # 模拟一批订单并发出售商品
    order_list = [(1, 50), (2, 20), (3, 30)]

    current_stock = {}
    for goods_id, _ in order_list:
        stock = Inventory.get(Inventory.goods == goods_id).stocks
        current_stock[goods_id] = stock
    print(f"当前库存:{current_stock}")

    num_threads = 2  # 可根据需要调整线程数
    threads = []
    for _ in range(num_threads):
        t = threading.Thread(target=sell, args=(order_list,))
        threads.append(t)

    for t in threads:
        t.start()

    for t in threads:
        t.join()

    # 查询库存
    final_stock = {}
    for goods_id, _ in order_list:
        stock = Inventory.get(Inventory.goods == goods_id).stocks
        final_stock[goods_id] = stock
    print(f"最终库存:{final_stock}")

输出结果为:

当前库存:{1: 100, 2: 100, 3: 100}
订单成功,商品列表:[(1, 50), (2, 20), (3, 30)]
订单成功,商品列表:[(1, 50), (2, 20), (3, 30)]
最终库存:{1: 0, 2: 60, 3: 40}

在这个示例中,我们使用了Redis的redis-py库来实现分布式锁。在每次查询库存之前,我们尝试获取一个以商品ID命名的Redis锁。只有成功获取锁的线程才能继续执行查询和库存更新操作,其他线程会等待锁被释放。这样可以确保在并发情况下,每个商品的库存查询和更新都是串行的,从而避免超卖的问题。

这个实现使用了Redis分布式锁来防止并发更新库存,但仍然可能面临一些问题:

  1. 死锁问题:如果获取锁的线程在更新完库存之前崩溃或出现异常,那么其他线程可能会一直等待这个锁,导致死锁。为了避免死锁,可以使用带有超时的锁,确保锁在一定时间内自动释放。
  2. 性能问题:由于每个订单在更新库存时都需要获取锁,可能会导致并发性能下降。尤其是在订单量大的情况下,锁竞争会变得更加激烈,影响整体性能。
  3. 性能瓶颈:使用单个Redis服务器来管理分布式锁可能成为性能瓶颈。在高并发情况下,Redis服务器可能成为瓶颈,限制了系统的扩展性。
  4. 数据不一致问题:虽然分布式锁可以保证同一时间只有一个线程更新库存,但是如果更新操作成功后,系统发生崩溃或者异常,可能会导致库存数据和实际订单不一致。
  5. 慢查询问题:在高并发场景下,如果某个订单的库存更新操作比较慢,可能会导致其他订单的等待时间增加,影响整体系统的吞吐量。

如何续约

锁的过期时间应该设置多长?

  • 设置短了,那么业务还没完成,锁就过期了。
  • 设置长了,万一实例崩溃了,那么其它实例也长时间拿不到锁。

更严重的是,不管你设置多长,极端情况下,都会出现业务执行时间超过过期时间。

我们可以考虑在锁还没有过期的时候,再一次延长过期时间,那么:

  • 过期时间不必设置得很长,自动续约会帮我们设置好。
  • 如果实例崩溃了,则没有人再续约,过一段时间之后自然就会过期,其它实例就能拿到锁了。

续约其实就是对Rediskey延长过期时间,需要注意的时,续期也要判断锁是不是自己的,因为锁可能已经过期被其他实例获取了。

5.0版本

使用py-redis-lock

从图上看出作者和其它大多数用Redis实现分布式锁的思路类似(SET NX),但是他对每个锁多用了一个list类型键来做信号控制,如果客户端第一次尝试获取锁失败,可以选择在signal列表上阻塞一个timeout时间用来接收锁被释放的通知,Redis列表的这个特性保证了每次只有一个客户端接收到了锁释放的通知。而获取到锁的客户端在使用完后会在对应的信号列表上推送一个通知。另外,作者对锁超时还增加了一个刷新的功能来延长(Extend)对锁的占用,可以保证在持有锁的客户端上完成所有操作后才释放锁。个人认为这种设计的优点和需要注意的点如下:

优点:

  • 一方面避免客户端反复请求锁,另一方面通过list signal来让客户端决定是否要block自己;
  • 如果有设置超时,则等待超时后客户端仍然会再尝试获取一次锁而不是直接失败;
  • 这个算法不依赖客户端时间戳,也就没有time drift问题;
  • 结合Lua脚本做原子操作,如果再加上细粒度锁,个人认为基本可以满足各种高需求场景的分布式锁要求。

Warning

  • 自动刷新可能会造成饥饿问题,如果持有锁的客户端因为某种未知原因阻塞,并且开启了自动刷新锁,那其它客户端就跪了,所以需要使用者慎用刷新机制;
  • 如果没有设置超时,且持有锁的客户端无响应的情况下就会造成死锁;

源码如下

import threading
import weakref
from base64 import b64encode
from logging import getLogger
from os import urandom

from redis import StrictRedis

__version__ = '3.6.0'

logger = getLogger(__name__)

text_type = str
binary_type = bytes

# Check if the id match. If not, return an error code.
UNLOCK_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[1] then
        return 1
    else
        redis.call("del", KEYS[2])
        redis.call("lpush", KEYS[2], 1)
        redis.call("pexpire", KEYS[2], ARGV[2])
        redis.call("del", KEYS[1])
        return 0
    end
"""

# Covers both cases when key doesn't exist and doesn't equal to lock's id
EXTEND_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[1] then
        return 1
    elseif redis.call("ttl", KEYS[1]) < 0 then
        return 2
    else
        redis.call("expire", KEYS[1], ARGV[2])
        return 0
    end
"""

RESET_SCRIPT = b"""
    redis.call('del', KEYS[2])
    redis.call('lpush', KEYS[2], 1)
    redis.call('pexpire', KEYS[2], ARGV[2])
    return redis.call('del', KEYS[1])
"""

RESET_ALL_SCRIPT = b"""
    local locks = redis.call('keys', 'lock:*')
    local signal
    for _, lock in pairs(locks) do
        signal = 'lock-signal:' .. string.sub(lock, 6)
        redis.call('del', signal)
        redis.call('lpush', signal, 1)
        redis.call('expire', signal, 1)
        redis.call('del', lock)
    end
    return #locks
"""


class AlreadyAcquired(RuntimeError):
    pass


class NotAcquired(RuntimeError):
    pass


class AlreadyStarted(RuntimeError):
    pass


class TimeoutNotUsable(RuntimeError):
    pass


class InvalidTimeout(RuntimeError):
    pass


class TimeoutTooLarge(RuntimeError):
    pass


class NotExpirable(RuntimeError):
    pass


class Lock(object):
    """
    A Lock context manager implemented via redis SETNX/BLPOP.
    """
    unlock_script = None
    extend_script = None
    reset_script = None
    reset_all_script = None

    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, strict=True, signal_expire=1000):
        """
        :param redis_client:
            An instance of :class:`~StrictRedis`.
        :param name:
            The name (redis key) the lock should have.
        :param expire:
            The lock expiry time in seconds. If left at the default (None)
            the lock will not expire.
        :param id:
            The ID (redis value) the lock should have. A random value is
            generated when left at the default.
            Note that if you specify this then the lock is marked as "held". Acquires
            won't be possible.
        :param auto_renewal:
            If set to ``True``, Lock will automatically renew the lock so that it
            doesn't expire for as long as the lock is held (acquire() called
            or running in a context manager).
            Implementation note: Renewal will happen using a daemon thread with
            an interval of ``expire*2/3``. If wishing to use a different renewal
            time, subclass Lock, call ``super().__init__()`` then set
            ``self._lock_renewal_interval`` to your desired interval.
        :param strict:
            If set ``True`` then the ``redis_client`` needs to be an instance of ``redis.StrictRedis``.
        :param signal_expire:
            Advanced option to override signal list expiration in milliseconds. Increase it for very slow clients. Default: ``1000``.
        """
        if strict and not isinstance(redis_client, StrictRedis):
            raise ValueError("redis_client must be instance of StrictRedis. "
                             "Use strict=False if you know what you're doing.")
        if auto_renewal and expire is None:
            raise ValueError("Expire may not be None when auto_renewal is set")

        self._client = redis_client

        if expire:
            expire = int(expire)
            if expire < 0:
                raise ValueError("A negative expire is not acceptable.")
        else:
            expire = None
        self._expire = expire

        self._signal_expire = signal_expire
        if id is None:
            self._id = b64encode(urandom(18)).decode('ascii')
        elif isinstance(id, binary_type):
            try:
                self._id = id.decode('ascii')
            except UnicodeDecodeError:
                self._id = b64encode(id).decode('ascii')
        elif isinstance(id, text_type):
            self._id = id
        else:
            raise TypeError("Incorrect type for `id`. Must be bytes/str not %s." % type(id))
        self._name = 'lock:' + name
        self._signal = 'lock-signal:' + name
        self._lock_renewal_interval = (float(expire) * 2 / 3
                                       if auto_renewal
                                       else None)
        self._lock_renewal_thread = None

        self.register_scripts(redis_client)

    @classmethod
    def register_scripts(cls, redis_client):
        global reset_all_script
        if reset_all_script is None:
            reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)
            cls.unlock_script = redis_client.register_script(UNLOCK_SCRIPT)
            cls.extend_script = redis_client.register_script(EXTEND_SCRIPT)
            cls.reset_script = redis_client.register_script(RESET_SCRIPT)
            cls.reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)

    @property
    def _held(self):
        return self.id == self.get_owner_id()

    def reset(self):
        """
        Forcibly deletes the lock. Use this with care.
        """
        self.reset_script(client=self._client, keys=(self._name, self._signal), args=(self.id, self._signal_expire))

    @property
    def id(self):
        return self._id

    def get_owner_id(self):
        owner_id = self._client.get(self._name)
        if isinstance(owner_id, binary_type):
            owner_id = owner_id.decode('ascii', 'replace')
        return owner_id

    def acquire(self, blocking=True, timeout=None):
        """
        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
        :param timeout:
            An integer value specifying the maximum number of seconds to block.
        """
        logger.debug("Getting %r ...", self._name)

        if self._held:
            raise AlreadyAcquired("Already acquired from this Lock instance.")

        if not blocking and timeout is not None:
            raise TimeoutNotUsable("Timeout cannot be used if blocking=False")

        if timeout:
            timeout = int(timeout)
            if timeout < 0:
                raise InvalidTimeout("Timeout (%d) cannot be less than or equal to 0" % timeout)

            if self._expire and not self._lock_renewal_interval and timeout > self._expire:
                raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))

        busy = True
        blpop_timeout = timeout or self._expire or 0
        timed_out = False
        while busy:
            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
            if busy:
                if timed_out:
                    return False
                elif blocking:
                    timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
                else:
                    logger.debug("Failed to get %r.", self._name)
                    return False
        # 是否应该取刷新过期时间,不是一定要这样做, 这是有风险, 如果当前的进程没有挂,但是一直阻塞,退不出来,就会永远续租
        logger.debug("Got lock for %r.", self._name)
        if self._lock_renewal_interval is not None:
            self._start_lock_renewer()
        return True

    def extend(self, expire=None):
        """Extends expiration time of the lock.
        :param expire:
            New expiration time. If ``None`` - `expire` provided during
            lock initialization will be taken.
        """
        if expire:
            expire = int(expire)
            if expire < 0:
                raise ValueError("A negative expire is not acceptable.")
        elif self._expire is not None:
            expire = self._expire
        else:
            raise TypeError(
                "To extend a lock 'expire' must be provided as an "
                "argument to extend() method or at initialization time."
            )

        error = self.extend_script(client=self._client, keys=(self._name, self._signal), args=(self._id, expire))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error == 2:
            raise NotExpirable("Lock %s has no assigned expiration time" % self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from EXTEND script" % error)

    @staticmethod
    def _lock_renewer(lockref, interval, stop):
        """
        Renew the lock key in redis every `interval` seconds for as long
        as `self._lock_renewal_thread.should_exit` is False.
        """
        log = getLogger("%s.lock_refresher" % __name__)
        while not stop.wait(timeout=interval):
            log.debug("Refreshing lock")
            lock = lockref()
            if lock is None:
                log.debug("The lock no longer exists, "
                          "stopping lock refreshing")
                break
            lock.extend(expire=lock._expire)
            del lock
        log.debug("Exit requested, stopping lock refreshing")

    def _start_lock_renewer(self):
        """
        Starts the lock refresher thread.
        """
        if self._lock_renewal_thread is not None:
            raise AlreadyStarted("Lock refresh thread already started")

        logger.debug(
            "Starting thread to refresh lock every %s seconds",
            self._lock_renewal_interval
        )
        self._lock_renewal_stop = threading.Event()
        self._lock_renewal_thread = threading.Thread(
            group=None,
            target=self._lock_renewer,
            kwargs={'lockref': weakref.ref(self),
                    'interval': self._lock_renewal_interval,
                    'stop': self._lock_renewal_stop}
        )
        self._lock_renewal_thread.setDaemon(True)
        self._lock_renewal_thread.start()

    def _stop_lock_renewer(self):
        """
        Stop the lock renewer.
        This signals the renewal thread and waits for its exit.
        """
        if self._lock_renewal_thread is None or not self._lock_renewal_thread.is_alive():
            return
        logger.debug("Signalling the lock refresher to stop")
        self._lock_renewal_stop.set()
        self._lock_renewal_thread.join()
        self._lock_renewal_thread = None
        logger.debug("Lock refresher has stopped")

    def __enter__(self):  # 用来使用with语句
        acquired = self.acquire(blocking=True)
        assert acquired, "Lock wasn't acquired, but blocking=True"
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        self.release()

    def release(self):
        """Releases the lock, that was acquired with the same object.
        .. note::
            If you want to release a lock that you acquired in a different place you have two choices:
            * Use ``Lock("name", id=id_from_other_place).release()``
            * Use ``Lock("name").reset()``
        """
        if self._lock_renewal_thread is not None:
            self._stop_lock_renewer()
        logger.debug("Releasing %r.", self._name)
        error = self.unlock_script(client=self._client, keys=(self._name, self._signal),
                                   args=(self._id, self._signal_expire))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from EXTEND script." % error)

    def locked(self):
        """
        Return true if the lock is acquired.
        Checks that lock with same name already exists. This method returns true, even if
        lock have another id.
        """
        return self._client.exists(self._name) == 1


reset_all_script = None


def reset_all(redis_client):
    """
    Forcibly deletes all locks if its remains (like a crash reason). Use this with care.
    :param redis_client:
        An instance of :class:`~StrictRedis`.
    """
    Lock.register_scripts(redis_client)

    reset_all_script(client=redis_client)  # noqa

作者定义了UNLOCK, EXTEND, RESET… 等原子操作的Lua脚本。

如果指定了锁自动刷新,那刷新间隔会设定在超时的2/3时间。

这个库提供的分布式锁很灵活,是否需要超时?是否需要自动刷新?是否要阻塞?都是可选的。没有最好的算法,只有最合适的算法,用户应该根据自己是场景谨慎选择。

使用py-redis-lock的锁实现

import threading

import redis
from datetime import datetime

from peewee import *
from playhouse.shortcuts import ReconnectMixin
from playhouse.pool import PooledMySQLDatabase

from py_redis_lock import Lock as PyLock


class ReconnectMySQLDatabase(ReconnectMixin, PooledMySQLDatabase):
    pass


db = ReconnectMySQLDatabase("inventory", host="192.168.91.1", port=3306, user="root", password="root")


# 删除 - 物理删除和逻辑删除 - 物理删除  -假设你把某个用户数据 - 用户购买记录,用户的收藏记录,用户浏览记录啊
# 通过save方法做了修改如何确保只修改update_time值而不是修改add_time
class BaseModel(Model):
    add_time = DateTimeField(default=datetime.now, verbose_name="添加时间")
    is_deleted = BooleanField(default=False, verbose_name="是否删除")
    update_time = DateTimeField(verbose_name="更新时间", default=datetime.now)

    def save(self, *args, **kwargs):
        # 判断这是一个新添加的数据还是更新的数据
        if self._pk is not None:
            # 这是一个新数据
            self.update_time = datetime.now()
        return super().save(*args, **kwargs)

    @classmethod
    def delete(cls, permanently=False):  # permanently表示是否永久删除
        if permanently:
            return super().delete()
        else:
            return super().update(is_deleted=True)

    def delete_instance(self, permanently=False, recursive=False, delete_nullable=False):
        if permanently:
            return self.delete(permanently).where(self._pk_expr()).execute()
        else:
            self.is_deleted = True
            self.save()

    @classmethod
    def select(cls, *fields):
        return super().select(*fields).where(cls.is_deleted == False)

    class Meta:
        database = db


class Inventory(BaseModel):
    # 商品的库存表
    # stock = PrimaryKeyField(Stock)
    goods = IntegerField(verbose_name="商品id", unique=True)
    stocks = IntegerField(verbose_name="库存数量", default=0)
    version = IntegerField(verbose_name="版本号", default=0)  # 分布式锁的乐观锁


def sell():
    # 多线程下的并发带来的数据不一致的问题
    goods_list = [(1, 10), (2, 20), (3, 30)]
    with db.atomic() as txn:
        # 超卖
        # 续租过期时间 - 看门狗 - java中有一个redisson
        # 如何防止我设置的值被其他的线程给删除掉
        for goods_id, num in goods_list:
            # 查询库存
            
            redis_client = redis.Redis(host="192.168.91.1")
            lock = PyLock(redis_client, f"lock:goods_{goods_id}", auto_renewal=True, expire=15)
            lock.acquire()
            goods_inv = Inventory.get(Inventory.goods == goods_id)
            import time
            time.sleep(20)
            if goods_inv.stocks < num:
                print(f"商品:{goods_id} 库存不足")
                txn.rollback()
                break
            else:
                # 让数据库根据自己当前的值更新数据, 这个语句能不能处理并发的问题
                query = Inventory.update(stocks=Inventory.stocks - num).where(Inventory.goods == goods_id)
                ok = query.execute()
                if ok:
                    print("更新成功")
                else:
                    print("更新失败")
            lock.release()


if __name__ == "__main__":
    t1 = threading.Thread(target=sell)
    t2 = threading.Thread(target=sell)
    t1.start()
    t2.start()

    t1.join()
    t2.join()

Redlock

前面讨论的都是单点的Redis,在集群部署的时候,需要额外考虑一个问题:主从切换

一切顺利的情况:

主从切换的异常情况:

Redlock 的思路:不再部署单一主从集群,而是多个主节点(没有从节点)

比如说我们部署五个主节点,那么加锁过程是类似的,只是要在五个主节点上都加上锁,如果多数(这里是三个)都成功了,那么就认为加锁成功。

示例:

import redis
import time

class Redlock(object):
    def __init__(self, connection_list, retry_times=3, retry_delay=200):
        self.servers = []
        for connection in connection_list:
            self.servers.append(redis.StrictRedis(host=connection["host"], port=connection["port"], db=connection["db"]))
        self.quorum = len(self.servers) // 2 + 1
        self.retry_times = retry_times
        self.retry_delay = retry_delay

    def lock(self, resource, ttl):
        retry = self.retry_times
        while retry > 0:
            n = 0
            start_time = time.time() * 1000
            for server in self.servers:
                if server.set(resource, 1, nx=True, px=ttl):
                    n += 1
            elapsed_time = time.time() * 1000 - start_time
            validity = ttl - elapsed_time - 2
            if n >= self.quorum and validity > 0:
                return validity
            else:
                for server in self.servers:
                    server.delete(resource)
                retry -= 1
                time.sleep(self.retry_delay / 1000)
        return False

    def unlock(self, resource):
        for server in self.servers:
            server.delete(resource)

# Example usage:
redlock = Redlock([{"host": "localhost", "port": 6379, "db": 0}], retry_times=3, retry_delay=200)
validity = redlock.lock("my_resource", 10000)
if validity:
    print("Lock acquired")
    # Do something here...
    redlock.unlock("my_resource")
else:
    print("Failed to acquire lock")

Redlock算法的基本思想是,使用多个Redis实例来协调锁,以确保在任何情况下都不会出现死锁或竞争条件。Redlock算法的实现比较复杂,需要考虑多个因素,例如时钟漂移、网络延迟等。

感兴趣的也可以参考第三方库https://github.com/SPSCommerce/redlock-py/blob/master/redlock/init.py的实现

基于Redis分布式锁的优缺点

  • 优点

    • 性能高
    • 简单
    • redis本身使用很频繁,这样的话不需要我们不需要去额外维护
  • 缺点

    • 依赖第三方组件
    • 单机的redis挂掉的可能性较高
    • 引入redis集群会有一些额外的问题 - redlock

优化

Redis分布式锁是实现分布式锁的一种常用方式,以下是一些可以优化Redis分布式锁的方法:

  1. 使用 RedLock 算法:在 Redis 分布式锁中,为了防止发生死锁,可以使用 RedLock 算法。这种算法是将锁分配到多个 Redis 实例上,通过协同工作来实现分布式锁的目的。当某个 Redis 实例无法正常工作时,其他实例可以继续提供服务,从而避免出现死锁的情况。
  2. 降低 Redis 的网络延迟:在使用 Redis 分布式锁时,网络延迟可能会导致性能问题。可以通过降低 Redis 的网络延迟来提高性能,例如使用本地的 Redis 实例,或者使用高速网络。
  3. 减少 Redis 的操作:在使用 Redis 分布式锁时,应该尽量减少 Redis 的操作次数,以提高性能。例如,可以将锁的持有者信息存储在本地内存中,而不是每次都从 Redis 中获取。
  4. 使用超时时间:在获取 Redis 分布式锁时,应该设置一个超时时间,以避免出现死锁的情况。当一个客户端获取锁后,在规定的时间内未能释放锁,其他客户端可以将其锁定的键值对删除,从而让其他客户端获取锁。
  5. 使用 Lua 脚本:Lua 脚本是 Redis 内置的一种脚本语言,可以用来实现一些复杂的操作,例如分布式锁。通过使用 Lua 脚本,可以将多个 Redis 操作封装成一个原子操作,从而提高性能和安全性。
  6. 使用 Sentinel 高可用方案:Sentinel 是 Redis 的高可用方案之一,它可以监控 Redis 实例的健康状态,并在发生故障时自动切换到备用实例。通过使用 Sentinel,可以提高 Redis 分布式锁的可用性和稳定性。

小结

  • 使用分布式锁,你不能指望框架提供万无一失的方案,自己还是要处理各种异常情况(超时)。
  • 自己写分布式锁,要考虑过期时间,以及要不要续约。
  • 不管要对锁做什么操作,首先要确认这把锁是我们自己的锁
  • 多数时候,与其选择复杂方案,不如直接让业务失败,可能成本还要低一点:有时候直接赔钱,比你部署一大堆节点,招一大堆开发,搞好几个机房还要便宜,而且便宜很多。
  • 选择恰好的方案,而不是完美的方案

参考

https://readthedocs.org/projects/python-redis-lock/downloads/pdf/latest/
https://redis.io/docs/manual/patterns/distributed-locks/
https://python-redis-lock.readthedocs.io/en/latest/
https://juejin.cn/post/7263878008615682104?searchId=202308141723011AA4D66584A48C0F786A
https://levelup.gitconnected.com/implementing-redlock-on-redis-for-distributed-locks-a3cfe60d4ea4

文章来源:https://blog.csdn.net/python_9k/article/details/135474841
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。