目录
书接上回《【JUC进阶】13. InheritableThreadLocal》,提到了InheritableThreadLocal虽然能进行父子线程的值传递,但是如果在线程池中,就无法达到预期的效果了。为了更好的解决该问题,TransmittableThreadLocal诞生了。
TransmittableThreadLocal?是Alibaba开源的、用于解决?“在使用线程池等会缓存线程的组件情况下传递ThreadLocal”?问题的 InheritableThreadLocal 扩展。既然是扩展,那么自然具备InheritableThreadLocal不同线程间值传递的能力。但是他也是专门为了解决InheritableThreadLocal在线程池中出现的问题的。
官网地址:https://github.com/alibaba/transmittable-thread-local
我们拿《【JUC进阶】13. InheritableThreadLocal》文中最后的demo进行改造。这里需要配合TtlExecutors一起使用。这里先讲述使用方法,具体为什么下面细说。
首先,我们需要添加依赖:
<!-- https://mvnrepository.com/artifact/com.alibaba/transmittable-thread-local -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>transmittable-thread-local</artifactId>
<version>2.14.2</version>
</dependency>
其次,ThreadLocal的实现改为TransmittableThreadLocal。
static ThreadLocal<String> threadLocal = new TransmittableThreadLocal<>();
最后创建线程池的时候,使用TTL装饰器:
static ExecutorService executorService = TtlExecutors.getTtlExecutorService(Executors.newSingleThreadExecutor());
完整代码如下:
// threadlocal改为TransmittableThreadLocal
static ThreadLocal<String> threadLocal = new TransmittableThreadLocal<>();
// 线程池添加TtlExecutors
static ExecutorService executorService = TtlExecutors.getTtlExecutorService(Executors.newSingleThreadExecutor());
public static void main(String[] args) throws InterruptedException {
//threadLocal.set("我是主线程的threadlocal变量,变量值为:000000");
// 线程池执行子线程
executorService.submit(() -> {
System.out.println("-----> 子线程" + Thread.currentThread() + " <----- 获取threadlocal变量:" + threadLocal.get());
});
// 主线程睡眠3s,模拟运行
Thread.sleep(3000);
// 将变量修改为11111,在InheritableThreadLocal中修改是无效的
threadLocal.set("我是主线程的threadlocal变量,变量值为:11111");
// 这里线程池重新执行线程任务
executorService.submit(() -> {
System.out.println("-----> 子线程" + Thread.currentThread() + " <----- 获取threadlocal变量:" + threadLocal.get());
});
// 线程池关闭
executorService.shutdown();
}
执行看下效果:
已经成功获取到threadlocal变量。
该方式也解决了因为线程被重复利用,而threadlocal重新赋值失效的问题。
首先可以看到TransmittableThreadLocal继承InheritableThreadLocal,同时实现了TtlCopier接口。TtlCopier接口只提供了一个方法copy()。看到这里,可能有人大概猜出来他的实现原理了,既然实现了copy()方法,那么大概率是将父线程的变量复制一份存起来,接着找个地方存起来,然后找个适当的时机再还回去。没错,其实就是这样。
public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {
}
知道了TransmittableThreadLocal类的定义之后,我们再来看一个重要的属性holder:
// Note about the holder:
// 1. holder self is a InheritableThreadLocal(a *ThreadLocal*).
// 2. The type of value in the holder is WeakHashMap<TransmittableThreadLocal<Object>, ?>.
// 2.1 but the WeakHashMap is used as a *Set*:
// the value of WeakHashMap is *always* null, and never used.
// 2.2 WeakHashMap support *null* value.
private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
@Override
protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
return new WeakHashMap<>();
}
@Override
protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
return new WeakHashMap<>(parentValue);
}
};
这里存放的是一个全局的WeakMap(同ThreadLocal一样,weakMap也是为了解决内存泄漏的问题),里面存放了TransmittableThreadLocal对象并且重写了initialValue和childValue方法,尤其是childValue,可以看到在即将异步时父线程的属性是直接作为初始化值赋值给子线程的本地变量对象。引入holder变量后,也就不必对外暴露Thread中的?inheritableThreadLocals,保持ThreadLocal.ThreadLocalMap的封装性。
而TransmittableThreadLocal中的get()和set()方法,都是从该holder中获取或添加该map。
重点来了,前面不是提到了需要借助于TtlExecutors.getTtlExecutorService()包装线程池才能达到效果吗?我们来看看这里做了什么事。
我们从TtlExecutors.getTtlExecutorService()方法跟进可以发现一个线程池的ttl包装类ExecutorServiceTtlWrapper。其中包含了我们执行线程的方法submit()和execute()。我们进入submit()方法:
@NonNull
@Override
public <T> Future<T> submit(@NonNull Callable<T> task) {
return executorService.submit(TtlCallable.get(task, false, idempotent));
}
可以发现在线程池进行任务执行时,对我们提交的任务进行了一层预处理,TtlCallable.get()。TtlCallable也是Callable的装饰类,同样还有TtlRunnable,也是同样道理。我们跟进该方法偷瞄一眼:
@Nullable
@Contract(value = "null, _, _ -> null; !null, _, _ -> !null", pure = true)
public static <T> TtlCallable<T> get(@Nullable Callable<T> callable, boolean releaseTtlValueReferenceAfterCall, boolean idempotent) {
if (callable == null) return null;
if (callable instanceof TtlEnhanced) {
// avoid redundant decoration, and ensure idempotency
if (idempotent) return (TtlCallable<T>) callable;
else throw new IllegalStateException("Already TtlCallable!");
}
return new TtlCallable<>(callable, releaseTtlValueReferenceAfterCall);
}
上面判断下当前线程的类型是否已经是TtlEnhanced,如果是直接返回,否则创建一个TtlCallable。接着进入new TtlCallable()方法:
private TtlCallable(@NonNull Callable<V> callable, boolean releaseTtlValueReferenceAfterCall) {
this.capturedRef = new AtomicReference<>(capture());
this.callable = callable;
this.releaseTtlValueReferenceAfterCall = releaseTtlValueReferenceAfterCall;
}
可以看到在初始化线程的时候,调用了一个capture()方法,并将该方法得到的值存放在capturedRef中。没错,这里就是上面我们提到的将父线程的本地变量复制一份快照,存放起来。跟进capture():
@NonNull
public static Object capture() {
final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = newHashMap(transmitteeSet.size());
for (Transmittee<Object, Object> transmittee : transmitteeSet) {
try {
transmittee2Value.put(transmittee, transmittee.capture());
} catch (Throwable t) {
if (logger.isLoggable(Level.WARNING)) {
logger.log(Level.WARNING, "exception when Transmitter.capture for transmittee " + transmittee +
"(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
}
}
}
return new Snapshot(transmittee2Value);
}
这里的transmitteeSet是一个存放Transmitteedede 集合,在初始化中会将我们 前面提到的holder注册进去:
private static final Set<Transmittee<Object, Object>> transmitteeSet = new CopyOnWriteArraySet<>();
static {
registerTransmittee(ttlTransmittee);
registerTransmittee(threadLocalTransmittee);
}
@SuppressWarnings("unchecked")
public static <C, B> boolean registerTransmittee(@NonNull Transmittee<C, B> transmittee) {
return transmitteeSet.add((Transmittee<Object, Object>) transmittee);
}
跟进transmittee.capture()方法,该方法由静态内部类Transmitter实现并重写,com.alibaba.ttl.TransmittableThreadLocal.Transmitter.Transmittee#capture
private static final Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>> ttlTransmittee =
new Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>>() {
@NonNull
@Override
public HashMap<TransmittableThreadLocal<Object>, Object> capture() {
final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = newHashMap(holder.get().size());
for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
ttl2Value.put(threadLocal, threadLocal.copyValue());
}
return ttl2Value;
}
}
transmittee.capture()扫描holder里目前存放的k-v里的key,就是需要传给子线程的TTL对象,其中调用的threadLocal.copyValue()便是前面看到的TtlCopier接口提供的方法。
看到这里已经大致符合我们前面的猜想,将变量复制一份存起来。那么不出意外接下来应该就是要找个适当的机会还回去。我们接着看。
接下来我们看真正执行线程的时候,也就是call()方法。由于前面线程被TtlCallable包装过,以为这里的call()方法肯定是TtlCallable.call():
@Override
@SuppressFBWarnings("THROWS_METHOD_THROWS_CLAUSE_BASIC_EXCEPTION")
public V call() throws Exception {
// 获取由之前捕获到的父线程变量集
final Object captured = capturedRef.get();
if (captured == null || releaseTtlValueReferenceAfterCall && !capturedRef.compareAndSet(captured, null)) {
throw new IllegalStateException("TTL value reference is released after call!");
}
// 这里的backup是当前线程原有的变量,这里进行备份,等线程执行完毕后,会将该变量进行恢复
final Object backup = replay(captured);
try {
// 任务执行
return callable.call();
} finally {
// 恢复上述提到的backup原有变量
restore(backup);
}
}
果然,在执行线程时,先获取之前存放起来的变量。然后调用replay():
@NonNull
public static Object replay(@NonNull Object captured) {
final Snapshot capturedSnapshot = (Snapshot) captured;
final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = newHashMap(capturedSnapshot.transmittee2Value.size());
for (Map.Entry<Transmittee<Object, Object>, Object> entry : capturedSnapshot.transmittee2Value.entrySet()) {
Transmittee<Object, Object> transmittee = entry.getKey();
try {
Object transmitteeCaptured = entry.getValue();
transmittee2Value.put(transmittee, transmittee.replay(transmitteeCaptured));
} catch (Throwable t) {
if (logger.isLoggable(Level.WARNING)) {
logger.log(Level.WARNING, "exception when Transmitter.replay for transmittee " + transmittee +
"(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
}
}
}
return new Snapshot(transmittee2Value);
}
继续跟进transmittee.replay(transmitteeCaptured):
@NonNull
@Override
public HashMap<TransmittableThreadLocal<Object>, Object> replay(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
final HashMap<TransmittableThreadLocal<Object>, Object> backup = newHashMap(holder.get().size());
for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
TransmittableThreadLocal<Object> threadLocal = iterator.next();
// 这里便是所有原生的本地变量都暂时存储在backup里,用于之后恢复用
backup.put(threadLocal, threadLocal.get());
// clear the TTL values that is not in captured
// avoid the extra TTL values after replay when run task
// 这里检查如果当前变量不存在于捕获到的线程变量,那么就将他清除掉,对应线程的本地变量也清理掉
// 为什么要清除?因为从使用这个子线程做异步那里,捕获到的本地变量并不包含原生的变量,当前线程
// 在做任务时的首要目标,是将父线程里的变量完全传递给任务,如果不清除这个子线程原生的本地变量,
// 意味着很可能会影响到任务里取值的准确性。这也就是为什么上面需要做备份的原因。
if (!captured.containsKey(threadLocal)) {
iterator.remove();
threadLocal.superRemove();
}
}
// set TTL values to captured
setTtlValuesTo(captured);
// call beforeExecute callback
doExecuteCallback(true);
return backup;
}
继续跟进setTtlValuesTo(captured),这里就是把父线程本地变量赋值给当前线程了:
private static void setTtlValuesTo(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
TransmittableThreadLocal<Object> threadLocal = entry.getKey();
threadLocal.set(entry.getValue());
}
}
到这里基本的实现原理也差不多了,基本和我们前面猜想的一致。但是这里还少了前面提到的backup变量如何恢复的步骤,既然到这里了,一起看一下,跟进restore(backup):
public static void restore(@NonNull Object backup) {
for (Map.Entry<Transmittee<Object, Object>, Object> entry : ((Snapshot) backup).transmittee2Value.entrySet()) {
Transmittee<Object, Object> transmittee = entry.getKey();
try {
Object transmitteeBackup = entry.getValue();
transmittee.restore(transmitteeBackup);
} catch (Throwable t) {
if (logger.isLoggable(Level.WARNING)) {
logger.log(Level.WARNING, "exception when Transmitter.restore for transmittee " + transmittee +
"(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
}
}
}
}
继续看transmittee.restore(transmitteeBackup):
@Override
public void restore(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
// call afterExecute callback
doExecuteCallback(false);
for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
TransmittableThreadLocal<Object> threadLocal = iterator.next();
// clear the TTL values that is not in backup
// avoid the extra TTL values after restore
if (!backup.containsKey(threadLocal)) {
iterator.remove();
threadLocal.superRemove();
}
}
// restore TTL values
setTtlValuesTo(backup);
}
与replay类似,只是重复进行了将backup赋给当前线程的步骤。到此基本结束。附上官网的时序图帮助理解:
所以总结下来,TransmittableThreadLocal的实现原理主要就是依赖于TtlRunnable或TtlCallable装饰类的预处理方法,TtlExecutors是将普通线程转换成Ttl包装的线程,而ttl包装的线程会进行本地变量的预处理,也就是capture()拷贝一份快照到内存中,然后通过replay方法将父线程的变量赋值给当前线程。