diff --git a/tinygrad/device.py b/tinygrad/device.py index b56f1e6691..3ebebb07a7 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -44,6 +44,7 @@ class _Device: return device except StopIteration as exc: raise RuntimeError("no usable devices") from exc Device = _Device() +atexit.register(lambda: [Device[dn].finalize() for dn in Device._opened_devices]) # **************** Profile **************** @@ -300,6 +301,11 @@ class Compiled: Called at the end of profiling to allow the device to finalize any profiling. """ # override this in your device implementation + def finalize(self): + """ + Called at the end of process lifetime to allow the device to finalize. + """ + # override this in your device implementation # TODO: move this to each Device def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool: diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index daba3e35b7..ba21fa7603 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Any, cast -import os, ctypes, ctypes.util, functools, mmap, errno, array, contextlib, sys, select, atexit +import os, ctypes, ctypes.util, functools, mmap, errno, array, contextlib, sys, select assert sys.platform != 'win32' from dataclasses import dataclass from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQSignal, HCQProgram, HWInterface @@ -601,8 +601,6 @@ class AMDDevice(HCQCompiled): self.max_private_segment_size = 0 self._ensure_has_local_memory(128) # set default scratch size to 128 bytes per thread - atexit.register(self.device_fini) - def create_queue(self, queue_type, ring_size, ctx_save_restore_size=0, eop_buffer_size=0, ctl_stack_size=0, debug_memory_size=0): ring = self.dev_iface.alloc(ring_size, uncached=True, cpu_access=True) gart = self.dev_iface.alloc(0x1000, uncached=True, cpu_access=True) @@ -632,6 +630,6 @@ class AMDDevice(HCQCompiled): def on_device_hang(self): self.dev_iface.on_device_hang() - def device_fini(self): + def finalize(self): self.synchronize() if hasattr(self.dev_iface, 'device_fini'): self.dev_iface.device_fini() diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index ee117d9495..7ce9deae09 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -300,6 +300,7 @@ class AMDev: if DEBUG >= 2: print(f"am {self.devfmt}: boot done") def fini(self): + if DEBUG >= 2: print(f"am {self.devfmt}: Finalizing") for ip in [self.sdma, self.gfx]: ip.fini() self.smu.set_clocks(level=0) self.ih.interrupt_handler()