From 434cfa96a3c83bc9a90bf5d17ae4f51d634a3c42 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Fri, 29 May 2026 14:11:16 -0700 Subject: [PATCH 1/7] ci: no fetch in backend tests (#16438) should make for less actions cache thrashing --- extra/torch_backend/test.py | 4 ++-- test/backend/test_arange.py | 4 ++-- test/backend/test_schedule.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py index ef3043569a..f3a1bcc93f 100644 --- a/extra/torch_backend/test.py +++ b/extra/torch_backend/test.py @@ -240,9 +240,9 @@ class TestTorchBackend(unittest.TestCase): np.testing.assert_equal(result.cpu().numpy(), [3., 3., 2.]) def test_mnist_index(self): + # from tinygrad.nn.datasets import mnist + X_train, Y_train = Tensor.randint(60000, 1, 28, 28, dtype='uchar').realize(), Tensor.randint(60000, dtype='uchar').realize() GlobalCounters.reset() - from tinygrad.nn.datasets import mnist - X_train, Y_train, _, _ = mnist() X_train = torch.tensor(X_train.float().numpy(), device=device) Y_train = torch.tensor(Y_train.cast('int64').numpy(), device=device) samples = torch.randint(0, X_train.shape[0], (32,)) diff --git a/test/backend/test_arange.py b/test/backend/test_arange.py index faa8b48137..1e223f797d 100644 --- a/test/backend/test_arange.py +++ b/test/backend/test_arange.py @@ -125,8 +125,8 @@ class TestIndexing(unittest.TestCase): def test_index_mnist(self, noopt=1, op_limit=512*784*13, split_reduceop=0): # WEBGPU generates more ops due to bitpacking of < 4-byte dtypes if Device.DEFAULT == "WEBGPU": op_limit *= 15 - from tinygrad.nn.datasets import mnist - X_train, Y_train, _, _ = mnist() + # from tinygrad.nn.datasets import mnist + X_train, Y_train = Tensor.randint(DSET, 1, 28, 28, dtype='uchar').realize(), Tensor.randint(DSET, dtype='uchar').realize() with Context(NOOPT=noopt, SPLIT_REDUCEOP=split_reduceop): samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]).realize() GlobalCounters.reset() diff --git a/test/backend/test_schedule.py b/test/backend/test_schedule.py index 5aa8a7a5f3..3f1fffafe0 100644 --- a/test/backend/test_schedule.py +++ b/test/backend/test_schedule.py @@ -1059,9 +1059,9 @@ class TestSchedule(unittest.TestCase): self.assertEqual(b.tolist(), [False, False]) def test_mnist_val(self): - from tinygrad.nn.datasets import mnist + # from tinygrad.nn.datasets import mnist import torch - _, Y_train, _, _ = mnist() + Y_train = Tensor.randint(60000, dtype='uchar').realize() samples = Tensor.randint(BS:=getenv("BS", 512), high=cast(int,Y_train.shape[-1])).realize() yt = Tensor.randn(BS, 10).realize() loss = yt.sparse_categorical_crossentropy(Y_train[samples]) From ef50a4969367115b6e51c84b5ba39a477209ae60 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Fri, 29 May 2026 14:40:32 -0700 Subject: [PATCH 2/7] ci: macos dev matrix (#16436) --- .github/workflows/test.yml | 92 +++++++++++++------------------------- 1 file changed, 30 insertions(+), 62 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 84632d77ad..aa5c8b0f84 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -817,74 +817,42 @@ jobs: - name: Run process replay tests uses: ./.github/actions/process-replay - osxwebgpu: - name: MacOS (WebGPU) - runs-on: macos-14 - timeout-minutes: 10 + testmacos: + strategy: + fail-fast: false + matrix: + dev: + - 'CPU:CLANG' + - 'CPU:LLVM' + - 'CPU:LVP' + - 'METAL' + - 'WEBGPU' + + name: MacOS (DEV=${{ matrix.dev }}) + runs-on: macos-15 + timeout-minutes: 20 steps: - name: Checkout Code uses: actions/checkout@v6 - name: Setup Environment uses: ./.github/actions/setup-tinygrad with: - key: osx-webgpu - deps: testing - webgpu: 'true' - - name: Build WEBGPU Efficientnet - run: DEV=WEBGPU WEBGPU_BACKEND="WGPUBackendType_Metal" python3 -m examples.compile_efficientnet - - name: Run selected webgpu tests - run: DEV=WEBGPU WEBGPU_BACKEND="WGPUBackendType_Metal" python3 -m pytest -n=auto test/backend --durations=20 - #- name: Clean npm cache - # run: npm cache clean --force - #- name: Install Puppeteer - # run: npm install puppeteer - # this is also flaky - #- name: Run WEBGPU Efficientnet - # run: node test/web/test_webgpu.js - # this is flaky - #- name: Run VIZ tests as external package - # run: | - # mkdir $GITHUB_WORKSPACE/test_dir - # cd $GITHUB_WORKSPACE/test_dir - # python -m venv venv - # source venv/bin/activate - # pip install $GITHUB_WORKSPACE - # cp $GITHUB_WORKSPACE/test/web/test_viz.js . - # node test_viz.js - - name: Test ONNX Runner (WEBGPU) - run: DEV=WEBGPU python3 test/external/external_test_onnx_runner.py - - osxtests: - strategy: - fail-fast: false - matrix: - backend: [metal, llvm, cpu, lvp] - name: MacOS (${{ matrix.backend }}) - runs-on: macos-15 - timeout-minutes: 20 - steps: - - name: Checkout Code - uses: actions/checkout@v6 - - name: Setup Environment - uses: ./.github/actions/setup-tinygrad - with: - key: macos-${{ matrix.backend }}-minimal - deps: testing_unit - llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'lvp' }} - mesa: ${{ matrix.backend == 'lvp' && 'cpu' }} - - name: Set env - run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'metal' && 'DEV=METAL' || matrix.backend == 'lvp' && 'DEV=CPU:LVP' }}" >> $GITHUB_ENV - - name: Check Device.DEFAULT and print some source - run: | - python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU','LVP':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT" - DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus - - name: Run pytest (${{ matrix.backend }}) - run: python3 -m pytest -n=auto test/backend --durations=20 - - name: Run process replay tests - uses: ./.github/actions/process-replay - - name: Run macOS-specific unit test - if: matrix.backend == 'llvm' - run: python3 -m pytest test/unit/test_disk_tensor.py::TestDiskTensor::test_copy_to_cpu_not_truncated test/unit/test_cpu.py + key: macos-${{ matrix.dev }} + deps: testing_unit + python-version: '3.12' + llvm: ${{ contains(matrix.dev, 'LLVM') || contains(matrix.dev, 'LVP') }} + mesa: ${{ contains(matrix.dev, 'LVP') && 'cpu' || 'false' }} + webgpu: ${{ matrix.dev == 'WEBGPU' }} + - name: Set env + run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANG' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV + - name: Check Device.DEFAULT and print some source + run: | + python -c "from tinygrad import Device; from tinygrad.helpers import Target; assert Device.DEFAULT == Target.parse('${{ matrix.dev }}').device" + DEBUG=4 python test/test_tiny.py TestTiny.test_plus + - name: Run backend tests + run: python -m pytest -n=auto test/backend --durations=20 + - name: Run process replay tests + uses: ./.github/actions/process-replay # ****** Windows Tests ****** From 8ac62b28e509884b5f36e5ce2393fc13eacacdf2 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 29 May 2026 17:59:47 -0400 Subject: [PATCH 3/7] fix AffineGrid fusion (#16439) --- tinygrad/nn/onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py index b117f286d7..e6cd3fba1f 100644 --- a/tinygrad/nn/onnx.py +++ b/tinygrad/nn/onnx.py @@ -1000,7 +1000,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT if align_corners: return Tensor.linspace(-1, 1, steps, device=theta.device) return Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device) grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims)) - base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1) + base_grid = Tensor.stack(*reversed(grids), grids[0].const_like(1), dim=-1) base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1) return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1) From d943493b79eb5fba0c21860d1060f69d85ac5a25 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Fri, 29 May 2026 16:20:31 -0700 Subject: [PATCH 4/7] ci: remove duplicate op compile test (#16441) --- .github/workflows/test.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index aa5c8b0f84..e5bb498863 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -404,8 +404,6 @@ jobs: - name: Test openpilot model kernel count and gate usage run: | ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=18 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 - - name: Test openpilot CL compile fp16 - run: FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 - name: Test openpilot CL compile fp32 (test correctness) run: | DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx From c23652e4865648c8b2b76ff63e16d52feb6609de Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Fri, 29 May 2026 21:00:37 -0400 Subject: [PATCH 5/7] llama: minimize peak init mem (#16440) --- examples/mlperf/model_train.py | 5 +---- examples/mlperf/models/flat_llama.py | 27 ++++++++++++--------------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index fbe0e7ffb0..f294bb3071 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1419,10 +1419,7 @@ def train_llama3(): for p in optim.params: grad_dtype = dtypes.bfloat16 if p.dtype == FP8_DTYPE else p.dtype - if isinstance(p.device, tuple) and p.uop.axis is not None: - p.grad = Tensor.zeros(p.shape, dtype=grad_dtype, device=p.device[0]).shard_(p.device, axis=p.uop.axis).contiguous() - else: - p.grad = Tensor.zeros(p.shape, dtype=grad_dtype, device=p.device).contiguous() + p.grad = p.zeros_like(dtype=grad_dtype).contiguous() grads = [p.grad for p in optim.params] scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps) diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py index d8f032c6a6..12240d068c 100644 --- a/examples/mlperf/models/flat_llama.py +++ b/examples/mlperf/models/flat_llama.py @@ -222,14 +222,19 @@ class FlatTransformer: for v in get_parameters(self): v.shard_(device, axis=None) else: # flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer - self.wqkv.shard_(device, axis=1).realize() # (n_layers, out, dim) shard out - self.wo.shard_(device, axis=2).realize() # (n_layers, dim, in) shard in + def _shard_fp8(name:str, axis:int): + getattr(self, name).shard_(device, axis=axis) + self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False) + self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False) + Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name]) + _shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out + _shard_fp8("wo", 2) # (n_layers, dim, in) shard in if SPLIT_W13: - self.w1.shard_(device, axis=1).realize() - self.w3.shard_(device, axis=1).realize() + _shard_fp8("w1", 1) + _shard_fp8("w3", 1) else: - self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out - self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in + _shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out + _shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in self.attention_norm.shard_(device, axis=None).realize() self.ffn_norm.shard_(device, axis=None).realize() self.norm.weight.shard_(device, axis=None).realize() @@ -240,10 +245,6 @@ class FlatTransformer: for name in amax_dict: for i in range(len(amax_dict[name])): amax_dict[name][i] = amax_dict[name][i].to(device).contiguous().is_param_(False) - for name in self._fp8_inv_scale: - self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False) - for name in self._fp8_next_inv_scale: - self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False) def __call__(self, tokens:Tensor, save:bool=True): h = self.tok_embeddings(tokens) @@ -325,11 +326,7 @@ if __name__ == "__main__": # preallocate all the grad buffers and zero them out grad_dtype = lambda x: dtypes.bfloat16 if x.dtype in dtypes.fp8s else x.dtype - def _make_grad(x): - if isinstance(x.device, tuple) and x.uop.axis is not None: - return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device[0]).shard_(x.device, axis=x.uop.axis).contiguous() - return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device).contiguous() - grads = {x:_make_grad(x) for x in state.values() if x.is_param} + grads = {x:x.zeros_like(dtype=grad_dtype(x)).contiguous() for x in state.values() if x.is_param} fp8_amax = [t for ts in model._fp8_amax.values() for t in ts] fp8_grad_amax = [t for ts in model._fp8_grad_amax.values() for t in ts] From c377d0149108da2f3a7c94a5e62681222f5fbc4b Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Fri, 29 May 2026 18:16:56 -0700 Subject: [PATCH 6/7] ci: run dsp on tinygrad[testing] (#16442) --- .github/workflows/test.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e5bb498863..524af917b4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -556,8 +556,7 @@ jobs: uses: ./.github/actions/setup-tinygrad with: key: dsp-minimal - deps: testing_unit - pydeps: "onnx==1.18.0 onnxruntime ml_dtypes" + deps: testing llvm: "true" qemu: "true" - name: Set MOCKDSP env From cf55aaf01f806a72cd745c4ebe81275eab744304 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 29 May 2026 19:13:51 -0700 Subject: [PATCH 7/7] python prg is pkl uops (#16443) * python prg is pkl uops * refactor to use uop * refactor to u. --- tinygrad/runtime/ops_python.py | 146 ++++++++++++++++----------------- 1 file changed, 72 insertions(+), 74 deletions(-) diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 28b9089012..30ffa4625c 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -41,41 +41,42 @@ def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_ class PythonProgram: def __init__(self, name:str, lib:bytes, **kwargs): - self.uops: list[tuple[Ops, DType, list[int], Any]] = pickle.loads(lib) + self.uops: list[UOp] = pickle.loads(lib) + self.uop_to_index: dict[UOp, int] = {u:i for i,u in enumerate(self.uops)} + self.loop_ends: dict[UOp, int] = {u.src[1]:i for i, u in enumerate(self.uops) if u.op == Ops.END} def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw): st = time.perf_counter() warp = list(itertools.product(*[range(x) for x in local_size[::-1]])) warp_size = len(warp) void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE} - loop_ends: dict[int, int] = {srcs[1]:i for i, (uop, _, srcs, _) in enumerate(self.uops) if uop == Ops.END} for idxs in itertools.product(*[range(x) for x in global_size[::-1]]): - values: dict[int, Any] = {} + values: dict[UOp, Any] = {} pbufs: list[memoryview] = list(bufs) pvals: list[int] = list(vals) exec_masks = [[True] * warp_size] i = 0 while i < len(self.uops): - uop, dtype, srcs, arg = self.uops[i] - src_values = [values[v] for v in srcs if self.uops[v][0] not in void_ops] - src_dtypes = [self.uops[v][1] for v in srcs if self.uops[v][0] not in void_ops] - if getenv("TRACE"): print(i, uop, dtype, arg, src_values, src_dtypes) - if uop is Ops.END: - i = srcs[1] + u = self.uops[i] + src_values = [values[v] for v in u.src if v.op not in void_ops] + src_dtypes = [v.dtype for v in u.src if v.op not in void_ops] + if getenv("TRACE"): print(i, u.op, u.dtype, u.arg, src_values, src_dtypes) + if u.op is Ops.END: + i = self.uop_to_index[u.src[1]] continue - if uop is Ops.IF: + if u.op is Ops.IF: exec_masks.append([x and y for x,y in zip(exec_masks[-1], src_values[0])]) i += 1 continue - if uop is Ops.ENDIF: + if u.op is Ops.ENDIF: exec_masks.pop() i += 1 continue - if uop in (Ops.BARRIER, Ops.SINK, Ops.NOOP, Ops.GROUP): + if u.op in (Ops.BARRIER, Ops.SINK, Ops.NOOP, Ops.GROUP): # in the python emulator, the warp is always in sync i += 1 continue - assert dtype is not None, f"{uop} is missing a dtype" - if uop is Ops.STORE: + assert u.dtype is not None, f"{u.op} is missing a dtype" + if u.op is Ops.STORE: assert len(src_values) == 2, f"STORE must be lowered to 2 srcs, got {len(src_values)}" store_gate = exec_masks[-1] for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]): @@ -83,25 +84,25 @@ class PythonProgram: if g: _store(m, o+j, v, src_dtypes[1].scalar()) i += 1 continue - if uop is Ops.AFTER: values[i] = src_values[0] - elif uop in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: - assert isinstance(dtype, PtrDType), dtype - storage_fmt = storage_fmt_for_dtype(dtype.base.scalar()) - if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported") + if u.op is Ops.AFTER: values[u] = src_values[0] + elif u.op in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: + assert isinstance(u.dtype, PtrDType), u.dtype + storage_fmt = storage_fmt_for_dtype(u.dtype.base.scalar()) + if storage_fmt is None: raise RuntimeError(f"dtype={u.dtype} is not supported") if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e" - if uop is Ops.DEFINE_REG: + if u.op is Ops.DEFINE_REG: # REGs are per thread - values[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)] + values[u] = [memoryview(bytearray(u.dtype.size*u.dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)] else: - buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.PARAM else pbufs.pop(0) - values[i] = [buf.cast(storage_fmt)] * warp_size - elif uop is Ops.DEFINE_VAR: - values[i] = [pvals.pop(0)] * warp_size - elif uop is Ops.SPECIAL: - if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size - elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp] - elif uop is Ops.CONST: values[i] = [arg] * warp_size - elif uop is Ops.INDEX: + buf = memoryview(bytearray(u.dtype.size*u.dtype.itemsize)) if u.op is not Ops.PARAM else pbufs.pop(0) + values[u] = [buf.cast(storage_fmt)] * warp_size + elif u.op is Ops.DEFINE_VAR: + values[u] = [pvals.pop(0)] * warp_size + elif u.op is Ops.SPECIAL: + if u.arg[0] == 'g': values[u] = [idxs[2-int(u.arg[-1])]] * warp_size + elif u.arg[0] == 'l': values[u] = [x[2-int(u.arg[-1])] for x in warp] + elif u.op is Ops.CONST: values[u] = [u.arg] * warp_size + elif u.op is Ops.INDEX: ret:list = [] if isinstance(src_dtypes[0], ImageDType): assert len(src_values) == 3, "image index must be 3 srcs" @@ -111,33 +112,33 @@ class PythonProgram: else: assert len(src_values) == 2, "non-image index must be 2 srcs" for m,o in zip(*src_values): ret.append((m,o)) - values[i] = ret - elif uop is Ops.CAST and isinstance(dtype, PtrDType): - values[i] = src_values[0] - elif uop is Ops.RANGE: - if i not in values: values[i] = [0] * warp_size + values[u] = ret + elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType): + values[u] = src_values[0] + elif u.op is Ops.RANGE: + if u not in values: values[u] = [0] * warp_size else: - for j in range(len(values[i])): - values[i][j] += 1 - if values[i][0] == src_values[0][0]: - del values[i] - i = loop_ends[i] + 1 + for j in range(len(values[u])): + values[u][j] += 1 + if values[u][0] == src_values[0][0]: + del values[u] + i = self.loop_ends[u] + 1 continue - elif uop is Ops.STACK: values[i] = src_values - elif uop is Ops.BITCAST: values[i] = [bitcast(x, src_dtypes[0], dtype) for x in src_values[0]] - elif uop is Ops.CAST: - values[i] = [truncate.get(dtype, lambda dt: dt)(dtype.const(x)) for x in src_values[0]] - elif uop is Ops.LOAD: - if dtype.count > 1: - values[i] = [load([src_values[i][j] if i != 0 and src_dtypes[i].count > 1 else src_values[i] \ - for i in range(len(src_values))], j, dtype.scalar()) for j in range(dtype.count)] + elif u.op is Ops.STACK: values[u] = src_values + elif u.op is Ops.BITCAST: values[u] = [bitcast(x, src_dtypes[0], u.dtype) for x in src_values[0]] + elif u.op is Ops.CAST: + values[u] = [truncate.get(u.dtype, lambda dt: dt)(u.dtype.const(x)) for x in src_values[0]] + elif u.op is Ops.LOAD: + if u.dtype.count > 1: + values[u] = [load([src_values[k][j] if k != 0 and src_dtypes[k].count > 1 else src_values[k] \ + for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(u.dtype.count)] else: - values[i] = load(src_values, 0, dtype) - elif uop is Ops.GEP: values[i] = src_values[0][get_single_element(arg)] - elif uop is Ops.WMMA: - first_src_dtype = self.uops[srcs[0]][1] + values[u] = load(src_values, 0, u.dtype) + elif u.op is Ops.GEP: values[u] = src_values[0][get_single_element(u.arg)] + elif u.op is Ops.WMMA: + first_src_dtype = u.src[0].dtype assert isinstance(first_src_dtype, DType) # mypy - dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5] + dims, dtype_in, device, threads = u.arg[1], first_src_dtype.scalar(), u.arg[4], u.arg[5] wmma_helper = functools.partial(generic_wmma_helper, src_values, warp_size) # TODO: refactor these to a shared TensorCoreLayout if device == "METAL": @@ -145,17 +146,17 @@ class PythonProgram: def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16] # (i, j), C, D (2 elements on 32 threads): row major same as A/B def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4) - values[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map) + values[u] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map) elif device == "AMD" and threads == 64: def a_elem(x, k, row, goff): return x[k%(dims[2]//4)][goff + (k//(dims[2]//4))*16 + row] def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem) - values[i] = wmma_helper(64, dims[2], len(src_values[0]), len(src_values[1]), len(src_values[2]), a_elem, b_elem, c_map) + values[u] = wmma_helper(64, dims[2], len(src_values[0]), len(src_values[1]), len(src_values[2]), a_elem, b_elem, c_map) elif device == "AMD" and len(src_values[0]) == 8: # RDNA4 def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]] def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem) - values[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map) + values[u] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map) elif device == "AMD": # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15 def a_elem(x, k, row, goff): @@ -164,7 +165,7 @@ class PythonProgram: # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15 def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major - values[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map) + values[u] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map) elif device == "CUDA": # (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8 def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8) @@ -172,24 +173,24 @@ class PythonProgram: if dims == (8,16,16): def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4] def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4] - values[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map) + values[u] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map) elif dims == (8,16,32): def a_elem(x, k, row, goff): return x[k%4 + (row//8)*4 + (k//16)*8][goff + (k//4)%4 + (row%8)*4] def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4] - values[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map) + values[u] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map) elif dims == (8,16,8) and dtype_in == dtypes.half: def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4] def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4] - values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) + values[u] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) elif dims == (8,16,8) and dtype_in == dtypes.float: def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4] def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4] - values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) + values[u] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) - else: raise NotImplementedError(f"unimplemented tensor core {arg}") + else: raise NotImplementedError(f"unimplemented tensor core {u.arg}") elif device == "INTEL": # A (16 elements on 8 threads) def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2] @@ -197,17 +198,17 @@ class PythonProgram: def b_elem(x, col, k, goff): return x[k][goff+col] # C, D (8 elements on 8 threads) def c_map(lane, elem): return (lane, elem) - values[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map) + values[u] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map) elif device == "CPU": def elem(x, col, row, _): return x[col+row][0] # k is always 0 def c_map(lane, elem): return (elem%16, elem//16) - values[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map) - else: raise NotImplementedError(f"unimplemented tensor core {arg}") - elif uop in GroupOp.ALU: - assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {uop}" - assert all_same([dtype] + src_dtypes) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}" - values[i] = [exec_alu(uop, dtype, p) for p in zip(*src_values)] - assert i in values, (uop, dtype, srcs, arg) + values[u] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map) + else: raise NotImplementedError(f"unimplemented tensor core {u.arg}") + elif u.op in GroupOp.ALU: + assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {u.op}" + assert all_same([u.dtype] + src_dtypes) or u.op in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {u.op}" + values[u] = [exec_alu(u.op, u.dtype, p) for p in zip(*src_values)] + assert u in values, u i += 1 return time.perf_counter() - st @@ -234,10 +235,7 @@ class PythonRenderer(Renderer): elif IMAGE and not target.arch: self.target = replace(target, arch="IMAGE_PITCH_ALIGNMENT=1") else: self.target = target - def render(self, uops:list[UOp]) -> str: - # the value of SPECIAL comes from local/global_size, not form its source - lops = [(u.op, u.dtype, [uops.index(v) for v in u.src if u.op is not Ops.SPECIAL], u.arg) for u in uops] - return base64.b64encode(pickle.dumps(lops)).decode() + def render(self, uops:list[UOp]) -> str: return base64.b64encode(pickle.dumps(uops)).decode() def supported_dtypes(self): return {d for d in super().supported_dtypes() if d != dtypes.half or sys.version_info >= (3, 12)}