webgpu: fix weight lookup in export_model after compile_net key change (#15919)

* fix lookup site in export_model_webgpu after refactoring

webgpu (sd): fix export_model weight lookup after compile_net changes

fix lookup site in export_model_webgpu after refactoring

* add regression test
This commit is contained in:
Denys Melnyk
2026-04-25 09:04:55 +02:00
committed by GitHub
parent 8b2826ef16
commit 1fdcb13bfb
3 changed files with 14 additions and 2 deletions

View File

@@ -114,7 +114,7 @@ if __name__ == "__main__":
linear, output_bufs = jit_model(step, *step.input)
functions, statements, bufs, _ = compile_net(linear, output_bufs)
state = get_state_dict(model)
weights = {id(x.uop.base.realized): name for name, x in state.items()}
weights = {(id(b), b.offset, b.size, b.dtype): name for name, x in state.items() if (b:=x.uop.base.realized) is not None}
kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _, _, _) in statements])
input_names = [f"input{i}" for i in range(len(step.input))]

View File

@@ -244,7 +244,7 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model"
with Context(JIT=2, CPU_COUNT=1): linear, output_bufs = jit_model(model, *inputs)
functions, statements, bufs, bufs_to_save = compile_net(linear, output_bufs)
state = get_state_dict(model)
weight_names = {id(x.uop.base.realized): name for name, x in state.items()}
weight_names = {(id(b), b.offset, b.size, b.dtype): name for name, x in state.items() if (b:=x.uop.base.realized) is not None}
input_names = [f"input{i}" for i in range(len(inputs))]
output_names = [f"output{i}" for i in range(len(output_bufs))]

View File

@@ -2,6 +2,8 @@ import unittest
from extra.export_model import export_model, EXPORT_SUPPORTED_DEVICE
from tinygrad.tensor import Tensor
from tinygrad.device import Device
from tinygrad.nn import Linear
from tinygrad.nn.state import get_state_dict
from tinygrad import dtypes
import json
@@ -66,5 +68,15 @@ class TextModelExportWebGPU(unittest.TestCase):
self.assertIn(f"const resultBuffer{i} = new {expected_arr_prefix}Array(gpuReadBuffer{i}.size/{dt.itemsize});", prg)
self.assertIn(f"resultBuffer{i}.set(new {expected_arr_prefix}Array(gpuReadBuffer{i}.getMappedRange()));", prg)
def test_weights_bound_to_safetensor(self):
# regression test: every weight ended up as createEmptyBuf (zero-init) instead of createWeightBuf
class MyModel:
def __init__(self): self.fc1, self.fc2 = Linear(4, 8), Linear(8, 2)
def forward(self, x): return self.fc2(self.fc1(x).relu())
model = MyModel()
for t in get_state_dict(model).values(): t.realize()
prg, _, _, _ = export_model(model, "webgpu", Tensor.randn(1, 4))
self.assertEqual(prg.count("createWeightBuf("), len(get_state_dict(model)))
if __name__ == '__main__':
unittest.main()