mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
webgpu: fix weight lookup in export_model after compile_net key change (#15919)
* fix lookup site in export_model_webgpu after refactoring
  webgpu (sd): fix export_model weight lookup after compile_net changes
  fix lookup site in export_model_webgpu after refactoring
* add regression test
This commit is contained in:
@@ -114,7 +114,7 @@ if __name__ == "__main__":
|
||||
linear, output_bufs = jit_model(step, *step.input)
|
||||
functions, statements, bufs, _ = compile_net(linear, output_bufs)
|
||||
state = get_state_dict(model)
|
||||
weights = {id(x.uop.base.realized): name for name, x in state.items()}
|
||||
weights = {(id(b), b.offset, b.size, b.dtype): name for name, x in state.items() if (b:=x.uop.base.realized) is not None}
|
||||
kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
|
||||
kernel_names = ', '.join([name for (name, _, _, _) in statements])
|
||||
input_names = [f"input{i}" for i in range(len(step.input))]
|
||||
|
||||
@@ -244,7 +244,7 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model"
|
||||
with Context(JIT=2, CPU_COUNT=1): linear, output_bufs = jit_model(model, *inputs)
|
||||
functions, statements, bufs, bufs_to_save = compile_net(linear, output_bufs)
|
||||
state = get_state_dict(model)
|
||||
weight_names = {id(x.uop.base.realized): name for name, x in state.items()}
|
||||
weight_names = {(id(b), b.offset, b.size, b.dtype): name for name, x in state.items() if (b:=x.uop.base.realized) is not None}
|
||||
input_names = [f"input{i}" for i in range(len(inputs))]
|
||||
output_names = [f"output{i}" for i in range(len(output_bufs))]
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ import unittest
|
||||
from extra.export_model import export_model, EXPORT_SUPPORTED_DEVICE
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.nn import Linear
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
from tinygrad import dtypes
|
||||
import json
|
||||
|
||||
@@ -66,5 +68,15 @@ class TextModelExportWebGPU(unittest.TestCase):
|
||||
self.assertIn(f"const resultBuffer{i} = new {expected_arr_prefix}Array(gpuReadBuffer{i}.size/{dt.itemsize});", prg)
|
||||
self.assertIn(f"resultBuffer{i}.set(new {expected_arr_prefix}Array(gpuReadBuffer{i}.getMappedRange()));", prg)
|
||||
|
||||
def test_weights_bound_to_safetensor(self):
|
||||
# regression test: every weight ended up as createEmptyBuf (zero-init) instead of createWeightBuf
|
||||
class MyModel:
|
||||
def __init__(self): self.fc1, self.fc2 = Linear(4, 8), Linear(8, 2)
|
||||
def forward(self, x): return self.fc2(self.fc1(x).relu())
|
||||
model = MyModel()
|
||||
for t in get_state_dict(model).values(): t.realize()
|
||||
prg, _, _, _ = export_model(model, "webgpu", Tensor.randn(1, 4))
|
||||
self.assertEqual(prg.count("createWeightBuf("), len(get_state_dict(model)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user