From 1fdcb13bfbf63f037f46bdd4c4c6e3227bf24c8a Mon Sep 17 00:00:00 2001 From: Denys Melnyk Date: Sat, 25 Apr 2026 09:04:55 +0200 Subject: [PATCH] webgpu: fix weight lookup in export_model after compile_net key change (#15919) * fix lookup site in export_model_webgpu after refactoring webgpu (sd): fix export_model weight lookup after compile_net changes fix lookup site in export_model_webgpu after refactoring * add regression test --- examples/webgpu/stable_diffusion/compile.py | 2 +- extra/export_model.py | 2 +- test/testextra/test_export_model.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/examples/webgpu/stable_diffusion/compile.py b/examples/webgpu/stable_diffusion/compile.py index fd926a988f..cfa2689705 100644 --- a/examples/webgpu/stable_diffusion/compile.py +++ b/examples/webgpu/stable_diffusion/compile.py @@ -114,7 +114,7 @@ if __name__ == "__main__": linear, output_bufs = jit_model(step, *step.input) functions, statements, bufs, _ = compile_net(linear, output_bufs) state = get_state_dict(model) - weights = {id(x.uop.base.realized): name for name, x in state.items()} + weights = {(id(b), b.offset, b.size, b.dtype): name for name, x in state.items() if (b:=x.uop.base.realized) is not None} kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()]) kernel_names = ', '.join([name for (name, _, _, _) in statements]) input_names = [f"input{i}" for i in range(len(step.input))] diff --git a/extra/export_model.py b/extra/export_model.py index a9a56a48a9..9d0125de34 100644 --- a/extra/export_model.py +++ b/extra/export_model.py @@ -244,7 +244,7 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model" with Context(JIT=2, CPU_COUNT=1): linear, output_bufs = jit_model(model, *inputs) functions, statements, bufs, bufs_to_save = compile_net(linear, output_bufs) state = get_state_dict(model) - weight_names = {id(x.uop.base.realized): name for name, x in state.items()} + weight_names = {(id(b), b.offset, b.size, b.dtype): name for name, x in state.items() if (b:=x.uop.base.realized) is not None} input_names = [f"input{i}" for i in range(len(inputs))] output_names = [f"output{i}" for i in range(len(output_bufs))] diff --git a/test/testextra/test_export_model.py b/test/testextra/test_export_model.py index dc246c526b..8b87ba9020 100644 --- a/test/testextra/test_export_model.py +++ b/test/testextra/test_export_model.py @@ -2,6 +2,8 @@ import unittest from extra.export_model import export_model, EXPORT_SUPPORTED_DEVICE from tinygrad.tensor import Tensor from tinygrad.device import Device +from tinygrad.nn import Linear +from tinygrad.nn.state import get_state_dict from tinygrad import dtypes import json @@ -66,5 +68,15 @@ class TextModelExportWebGPU(unittest.TestCase): self.assertIn(f"const resultBuffer{i} = new {expected_arr_prefix}Array(gpuReadBuffer{i}.size/{dt.itemsize});", prg) self.assertIn(f"resultBuffer{i}.set(new {expected_arr_prefix}Array(gpuReadBuffer{i}.getMappedRange()));", prg) + def test_weights_bound_to_safetensor(self): + # regression test: every weight ended up as createEmptyBuf (zero-init) instead of createWeightBuf + class MyModel: + def __init__(self): self.fc1, self.fc2 = Linear(4, 8), Linear(8, 2) + def forward(self, x): return self.fc2(self.fc1(x).relu()) + model = MyModel() + for t in get_state_dict(model).values(): t.realize() + prg, _, _, _ = export_model(model, "webgpu", Tensor.randn(1, 4)) + self.assertEqual(prg.count("createWeightBuf("), len(get_state_dict(model))) + if __name__ == '__main__': unittest.main()