TransmittableThreadLocal详解,源码分析,一文带你掌握核心逻辑

发布时间:2024年01月05日

holder 变量是一个InheritableThreadLocal, 他是一个map但是一直都是当作Set在用,value一直是空

The value of holder is type WeakHashMap<TransmittableThreadLocal, ?>, but it is used as Set (aka. do NOT use about value, always null).

每次使用holder变量都会带着.get() ,意味着每次获取到的WeakHashMap都是线程自己的,这是我之前一直不理解的点,记住,每次使用holder都会带上.get(),而不是真正全局使用一个WeakHashMap

源码分析

TransmittableThreadLocal

先来分析一下set()方法

  1. 首先会将当前的value设置到Thread中的ThreadLocalMap

  2. 然后将当前的TransmittableThreadLocal放入当前线程的map中(给未来打快照使用)

public final void set(T value) {
  // 首先先使用ThreadLocal的特性,将当前的value设置到Thread中的ThreadLocalMap,保证ThreadLocal的特性被保留
  super.set(value);
  // may set null to remove value
  // 如果是value是null,就从map中删除
  if (null == value) removeValue();
  // 这一步本质就是将当前的TransmittableThreadLocal放入当前线程的map中,以保存父线程的使用过的TransmittableThreadLocal
  else addValue();
}

private void addValue() {
  if (!holder.get().containsKey(this)) {
    // 将当前的TransmittableThreadLocal放入当前线程的map中
    // 这里value一直是空
    holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
  }
}

get()方法没有什么特殊的大家看看就行

public final T get() {
  T value = super.get();
  if (null != value) addValue();
  return value;
}

TtlRunnable

重要!!!

我认为整个TransmittableThreadLocal的核心就是使用装饰模式,将整个Runnable 包装了一层,实现了当线程复用的情况也可以将父线程继承到子线程的能力

首先在使用TtlRunnable.get(runnable),会将Runnable包装一层,此时的调用方就是父线程,在方法中会调用capture(),获取当前父线程的快照

打快照
public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
  if (null == runnable) return null;

  if (runnable instanceof TtlEnhanced) {
    // avoid redundant decoration, and ensure idempotency
    if (idempotent) return (TtlRunnable) runnable;
    else throw new IllegalStateException("Already TtlRunnable!");
  }
  return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}

// 构造函数
private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
  // capture()此方法是核心
  this.capturedRef = new AtomicReference<Object>(capture());
  this.runnable = runnable;
  this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}

这里主要是将父线程所有的threadLocal打一个快照

第一个是TransmittableThreadLocal

第二个是threadLocalHolder,这个threadLocalHolder的作用是对于在项目中使用了ThreadLocal,但是却无法替换为TransmittableThreadLocal的情况,可以使用Transmitter提供的注册方法,将项目中的threadLocal注册到它的threadLocalHolder中,后面进行capture等操作时holder和threadLocalHolder都会进行处理使用

public static Object capture() {
  return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}

private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
  
  WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
  // 核心:获取之前保存的当前线程使用过的TransmittableThreadLocal,然后将其放在一个map中,key为TransmittableThreadLocal,value为具体的值
  // 这样就形成了在父线程调用TtlRunnable.get(runnable)父线程使用TransmittableThreadLocal的快照
  for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
    ttl2Value.put(threadLocal, threadLocal.copyValue());
  }
  return ttl2Value;
}
// 这个是处理threadLocalHolder
private static WeakHashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
  final WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value = new WeakHashMap<ThreadLocal<Object>, Object>();
  for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
    final ThreadLocal<Object> threadLocal = entry.getKey();
    final TtlCopier<Object> copier = entry.getValue();

    threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
  }
  return threadLocal2Value;
}
run方法包装

接下来就是将Run方法包装了一层,注意调用run方法的一定是子线程

  1. 获取当前之前创建的父线程ThreadLocal快照

  2. 重放快照到子线程的Thread中

  3. 执行run方法

  4. 恢复子线程的快照

@Override
public void run() {
  // 获取当前之前创建的父线程ThreadLocal快照
  // 这里快照不应该为空
  Object captured = capturedRef.get();
  if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
    throw new IllegalStateException("TTL value reference is released after run!");
  }

  Object backup = replay(captured);
  try {
    runnable.run();
  } finally {
    restore(backup);
  }
}
重放快照
// 将父线程的ThreadLocal回放到子线程中
@NonNull
public static Object replay(@NonNull Object captured) {
  final Snapshot capturedSnapshot = (Snapshot) captured;
  return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}

// 回放快照
// 备份子线程threadLocal
@NonNull
private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
  WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
	// 将子线程自带的threadLocal给备份起来,其实也是打个快照而已
  for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
    TransmittableThreadLocal<Object> threadLocal = iterator.next();

    // backup
    // 将子线程自带的threadLocal给备份起来,其实也是打个快照而已
    backup.put(threadLocal, threadLocal.get());

    // 如果父线程快照中不存在当前ThreadLocal 就删掉这个threadLocal
    // 因为这一步就是为了把父线程的threadLocal放进子线程中
    // clear the TTL values that is not in captured
    // avoid the extra TTL values after replay when run task
    if (!captured.containsKey(threadLocal)) {
      iterator.remove();
      threadLocal.superRemove();
    }
  }

  // set TTL values to captured
  // 将父线程的threadLocal快照放到子线程中
  setTtlValuesTo(captured);

  // call beforeExecute callback
  doExecuteCallback(true);

  return backup;
}

private static void setTtlValuesTo(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
  // 将父线程的threadLocal快照放到子线程中
  for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
    TransmittableThreadLocal<Object> threadLocal = entry.getKey();
    threadLocal.set(entry.getValue());
  }
}

private static WeakHashMap<ThreadLocal<Object>, Object> replayThreadLocalValues(@NonNull WeakHashMap<ThreadLocal<Object>, Object> captured) {
  final WeakHashMap<ThreadLocal<Object>, Object> backup = new WeakHashMap<ThreadLocal<Object>, Object>();

  for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
    final ThreadLocal<Object> threadLocal = entry.getKey();
    backup.put(threadLocal, threadLocal.get());

    final Object value = entry.getValue();
    if (value == threadLocalClearMark) threadLocal.remove();
    else threadLocal.set(value);
  }

  return backup;
}
restore

恢复子线程的threadLocal现场

仔细看就会发现是replay的反向操作

public static void restore(@NonNull Object backup) {
    final Snapshot backupSnapshot = (Snapshot) backup;
    restoreTtlValues(backupSnapshot.ttl2Value);
    restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}
private static void restoreTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> backup) {
  // call afterExecute callback
  doExecuteCallback(false);
  // 查询子线程使用的TransmittableThreadLocal, 然后遍历它
  for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
    TransmittableThreadLocal<Object> threadLocal = iterator.next();

    // clear the TTL values that is not in backup
    // avoid the extra TTL values after restore
    // 原来不存在就删掉这个threadLocal
    if (!backup.containsKey(threadLocal)) {
      iterator.remove();
      threadLocal.superRemove();
    }
  }

  // restore TTL values
  setTtlValuesTo(backup);
}
private static void setTtlValuesTo(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
  for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
    // 遍历一下之前子线程的快照是存在的,把它恢复了
    TransmittableThreadLocal<Object> threadLocal = entry.getKey();
    threadLocal.set(entry.getValue());
  }
}
private static void restoreThreadLocalValues(@NonNull WeakHashMap<ThreadLocal<Object>, Object> backup) {
    for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        threadLocal.set(entry.getValue());
    }
}

总结

到这里分析就结束了

整体我认为这个设计的还是很巧妙的,解决了InheritableThreadLocal在线程复用(线程池的情况无法使用的问题)

  1. 首先使用了holder这样一个ThreadLocal,记录了每一个线程使用了哪些threadLocal,到时候可以直接将这个线程所有的thread以及value遍历出来

  2. 只用TtlRunnable把Runnable包装了一层,在调用.get时就把父线程打了个快照

  3. 把Runnable的run方法包装了一层,让线程开始执行之前回放父线程的threadLocal,执行结束后恢复子线程原来就有的threadLocal

    如果大家喜欢这篇文章的话,可以点赞收藏一下,这是对我最大的支持

    联系方式: xianchaolin@126.com

文章来源:https://blog.csdn.net/weixin_42293662/article/details/135415226
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。