RuntimeError: Inference tensors do not track version counter.
发布时间:2023年12月20日
问题:
Testing DataLoader 0: 0%|| 0/78 [00:00<?, ?it/s]
Failed to collect metadata on function, produced code may be suboptimal. Known situations this can occur are inference mode only compilation involving resize_ or prims (!schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED); if your situation looks different please file a bug to PyTorch.
Traceback (most recent call last):
File "/home/buty/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1674,in aot_wrapper_dedupe
fw_metadata, _out = run_functionalized_fw_and_collect_metadata(
File "/home/buty/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 606,in inner
flat_f_outs = f(*flat_f_args)
File "/home/buty/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2776,in functional_call
out = Interpreter(mod).run(*args[params_len:],**kwargs)
File "/home/buty/.local/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136,in run
self.env[node]= self.run_node(node)
File "/home/buty/.local/lib/python3.8/site-packages/torch/fx/interpreter.py", line 177,in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/home/buty/.local/lib/python3.8/site-packages/torch/fx/interpreter.py", line 294,in call_module
return submod(*args,**kwargs)
File "/home/buty/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501,in _call_impl
return forward_call(*args,**kwargs)
File "/home/buty/.local/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114,in forward
return F.linear(input, self.weight, self.bias)
File "/home/buty/.local/lib/python3.8/site-packages/torch/_inductor/overrides.py", line 38,in __torch_function__
return func(*args,**kwargs)
RuntimeError: Inference tensors do not track version counter.
While executing %self_model_fc :[#users=3] = call_module[target=self_model_fc](args = (%flatten,), kwargs = {})
Original traceback:
File "/home/buty/.local/lib/python3.8/site-packages/torchvision/models/resnet.py", line 280,in _forward_impl
x = self.fc(x)| File "/home/buty/.local/lib/python3.8/site-packages/torchvision/models/resnet.py", line 285,in forward
return self._forward_impl(x)| File "/opt/extend/buty/work/ocr/YuzuMarker.FontDetection/detector/model.py", line 53,in forward
X = self.model(X)
Traceback (most recent call last):
File "/home/buty/.local/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 670,in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/home/buty/.local/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 1055,in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/buty/.local/lib/python3.8/site-packages/torch/__init__.py", line 1390,in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/buty/.local/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 455,in compile_fx
return aot_autograd(
File "/home/buty/.local/lib/python3.8/site-packages/torch/_dynamo/backends/common.py", line 48,in compiler_fn
cg = aot_module_simplified(gm, example_inputs,**kwargs)
File "/home/buty/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2805,in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/home/buty/.local/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163,in time_wrapper
r = func(*args,**kwargs)
File "/home/buty/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2498,in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
File "/home/buty/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1802,in aot_wrapper_dedupe
compiled_fn = compiler_fn(wrapped_flat_fn, deduped_flat_args, aot_config)
File "/home/buty/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1278,in aot_dispatch_base
_fw_metadata, _out = run_functionalized_fw_and_collect_metadata(
File "/home/buty/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 606,in inner
flat_f_outs = f(*flat_f_args)
File "/home/buty/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1800,in wrapped_flat_fn
return flat_fn(*add_dupe_args(args))
File "/home/buty/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2776,in functional_call
out = Interpreter(mod).run(*args[params_len:],**kwargs)
File "/home/buty/.local/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136,in run
self.env[node]= self.run_node(node)
File "/home/buty/.local/lib/python3.8/site-packages/torch/fx/interpreter.py", line 177,in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/home/buty/.local/lib/python3.8/site-packages/torch/fx/interpreter.py", line 294,in call_module
return submod(*args,**kwargs)
File "/home/buty/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501,in _call_impl
return forward_call(*args,**kwargs)
File "/home/buty/.local/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114,in forward
return F.linear(input, self.weight, self.bias)
File "/home/buty/.local/lib/python3.8/site-packages/torch/_inductor/overrides.py", line 38,in __torch_function__
return func(*args,**kwargs)
RuntimeError: Inference tensors do not track version counter.
While executing %self_model_fc :[#users=3] = call_module[target=self_model_fc](args = (%flatten,), kwargs = {})
Original traceback:
File "/home/buty/.local/lib/python3.8/site-packages/torchvision/models/resnet.py", line 280,in _forward_impl
x = self.fc(x)| File "/home/buty/.local/lib/python3.8/site-packages/torchvision/models/resnet.py", line 285,in forward
return self._forward_impl(x)| File "/opt/extend/buty/work/ocr/YuzuMarker.FontDetection/detector/model.py", line 53,in forward
X = self.model(X)
The above exception was the direct cause of the following exception:
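原因:主要是因为使用了 torch.compile()。trainer.test() 默认会在 torch.inference_mode() 下运行模型,而 PyTorch 2.0 的 torch.compile(AOTAutograd)在编译时无法处理 inference tensor,于是抛出 "Inference tensors do not track version counter"。相关代码如下: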
if torch.__version__ >= "2.0" and os.name == "posix":
    model = torch.compile(model)
。。。。。。
#trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
trainer.test(detector, datamodule=data_module)
解决办法:把 model = torch.compile(model) 去掉(注释掉即可):
# if torch.__version__ >= "2.0" and os.name == "posix":
#     model = torch.compile(model)
。。。。。。
#trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
trainer.test(detector, datamodule=data_module)