This commit is contained in:
George Hotz
2026-06-03 12:09:46 -07:00
parent 460c50710c
commit e9569b8799
3 changed files with 15 additions and 12 deletions

View File

@@ -204,7 +204,8 @@ class TestLocalAccess(unittest.TestCase):
out = Device[Device.DEFAULT].renderer.render(uops)
# half is supported in wgsl, so it doesn't have to be packed
corrected_size = size//(4//dtype.itemsize) if dtype != dtypes.half else size
self.assertIn(f"temp0: array<{Device[Device.DEFAULT].renderer.buf_map(dtype)},{corrected_size}>;", out)
# temp0: array<{Device[Device.DEFAULT].renderer.buf_map(dtype)},{corrected_size}>;
self.assertIn(f",{corrected_size}>;", out)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
@unittest.skip("tinygrad doesn't support this behavior")

View File

@@ -269,6 +269,7 @@ _nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2
lambda b,img,idx_y,idx_x,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
class IR3Renderer(NIRRenderer, OpenCLRenderer):
new_style = False
has_aux = True
def nload_img(ctx,img,idx_y,idx_x):

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess, struct
assert sys.platform != 'win32'
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.uop.ops import Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import Ops, UOp, AddrSpace
from tinygrad.helpers import getenv, round_up, mv_address, to_mv, cpu_objdump, system, DEBUG, suppress_finalizing, Target
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.runtime.autogen import libc, qcom_dsp
@@ -59,13 +59,14 @@ class DSPRenderer(ClangRenderer):
'HAP_power_set((void*)handle, (void*)&req);']
msrc += ['if ((sc>>24) != 2) return 0;']
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs)
if isinstance(b[1][0].dtype, PtrDType)]
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if b[1][0].addrspace == AddrSpace.GLOBAL]
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};'
for i,b in enumerate(bufs) if b[1][0].addrspace == AddrSpace.GLOBAL]
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0].dtype, PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if b[1][0].addrspace == AddrSpace.GLOBAL else f'sz_or_val_{i}')
for i,b in enumerate(bufs)])});"]
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if b[1][0].addrspace == AddrSpace.GLOBAL]
msrc += ["return 0; }"]
return '\n'.join(msrc)
@@ -280,18 +281,18 @@ class MockDSPRenderer(DSPRenderer):
# control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
msrc = [mockdsp_boilerplate, 'void _start(void) {']
for i,b in enumerate(bufs):
if isinstance(b[1][0].dtype, PtrDType):
sz = b[1][0].dtype.size*b[1][0].dtype.itemsize
if b[1][0].addrspace == AddrSpace.GLOBAL:
sz = b[1][0].max_numel()*b[1][0].dtype.itemsize
# for loop for big reads
msrc.append(f"void *buf{i} = mmap2(0, {sz}, 3, 0x21, -1, 0); for(int rd = 0; rd < {sz}; rd += read(0, buf{i}+rd, {sz}-rd));")
else:
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
msrc.append("unsigned int st = inscount();")
params = [(f'(void*)buf{i}' if isinstance(b[1][0].dtype, PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
params = [(f'(void*)buf{i}' if b[1][0].addrspace == AddrSpace.GLOBAL else f'val{i}') for i,b in enumerate(bufs)]
msrc.append(f"{function_name}({', '.join(params)});")
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
for i,b in enumerate(bufs):
if isinstance(b[1][0].dtype, PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].dtype.size*b[1][0].dtype.itemsize});")
if b[1][0].addrspace == AddrSpace.GLOBAL: msrc.append(f"write(1, buf{i}, {b[1][0].max_numel()*b[1][0].dtype.itemsize});")
msrc.append('exit(0); }')
return '\n'.join(msrc)