Merge branch 'master' into shrink_in_render

This commit is contained in:
George Hotz
2026-05-29 19:17:11 -07:00
committed by GitHub
8 changed files with 123 additions and 166 deletions

View File

@@ -404,8 +404,6 @@ jobs:
- name: Test openpilot model kernel count and gate usage
run: |
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=18 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp16
run: FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp32 (test correctness)
run: |
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
@@ -558,8 +556,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: dsp-minimal
deps: testing_unit
pydeps: "onnx==1.18.0 onnxruntime ml_dtypes"
deps: testing
llvm: "true"
qemu: "true"
- name: Set MOCKDSP env
@@ -817,74 +814,42 @@ jobs:
- name: Run process replay tests
uses: ./.github/actions/process-replay
osxwebgpu:
name: MacOS (WebGPU)
runs-on: macos-14
timeout-minutes: 10
testmacos:
strategy:
fail-fast: false
matrix:
dev:
- 'CPU:CLANG'
- 'CPU:LLVM'
- 'CPU:LVP'
- 'METAL'
- 'WEBGPU'
name: MacOS (DEV=${{ matrix.dev }})
runs-on: macos-15
timeout-minutes: 20
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: osx-webgpu
deps: testing
webgpu: 'true'
- name: Build WEBGPU Efficientnet
run: DEV=WEBGPU WEBGPU_BACKEND="WGPUBackendType_Metal" python3 -m examples.compile_efficientnet
- name: Run selected webgpu tests
run: DEV=WEBGPU WEBGPU_BACKEND="WGPUBackendType_Metal" python3 -m pytest -n=auto test/backend --durations=20
#- name: Clean npm cache
# run: npm cache clean --force
#- name: Install Puppeteer
# run: npm install puppeteer
# this is also flaky
#- name: Run WEBGPU Efficientnet
# run: node test/web/test_webgpu.js
# this is flaky
#- name: Run VIZ tests as external package
# run: |
# mkdir $GITHUB_WORKSPACE/test_dir
# cd $GITHUB_WORKSPACE/test_dir
# python -m venv venv
# source venv/bin/activate
# pip install $GITHUB_WORKSPACE
# cp $GITHUB_WORKSPACE/test/web/test_viz.js .
# node test_viz.js
- name: Test ONNX Runner (WEBGPU)
run: DEV=WEBGPU python3 test/external/external_test_onnx_runner.py
osxtests:
strategy:
fail-fast: false
matrix:
backend: [metal, llvm, cpu, lvp]
name: MacOS (${{ matrix.backend }})
runs-on: macos-15
timeout-minutes: 20
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: macos-${{ matrix.backend }}-minimal
deps: testing_unit
llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'lvp' }}
mesa: ${{ matrix.backend == 'lvp' && 'cpu' }}
- name: Set env
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'metal' && 'DEV=METAL' || matrix.backend == 'lvp' && 'DEV=CPU:LVP' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU','LVP':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT"
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
- name: Run pytest (${{ matrix.backend }})
run: python3 -m pytest -n=auto test/backend --durations=20
- name: Run process replay tests
uses: ./.github/actions/process-replay
- name: Run macOS-specific unit test
if: matrix.backend == 'llvm'
run: python3 -m pytest test/unit/test_disk_tensor.py::TestDiskTensor::test_copy_to_cpu_not_truncated test/unit/test_cpu.py
key: macos-${{ matrix.dev }}
deps: testing_unit
python-version: '3.12'
llvm: ${{ contains(matrix.dev, 'LLVM') || contains(matrix.dev, 'LVP') }}
mesa: ${{ contains(matrix.dev, 'LVP') && 'cpu' || 'false' }}
webgpu: ${{ matrix.dev == 'WEBGPU' }}
- name: Set env
run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANG' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
python -c "from tinygrad import Device; from tinygrad.helpers import Target; assert Device.DEFAULT == Target.parse('${{ matrix.dev }}').device"
DEBUG=4 python test/test_tiny.py TestTiny.test_plus
- name: Run backend tests
run: python -m pytest -n=auto test/backend --durations=20
- name: Run process replay tests
uses: ./.github/actions/process-replay
# ****** Windows Tests ******

View File

@@ -1419,10 +1419,7 @@ def train_llama3():
for p in optim.params:
grad_dtype = dtypes.bfloat16 if p.dtype == FP8_DTYPE else p.dtype
if isinstance(p.device, tuple) and p.uop.axis is not None:
p.grad = Tensor.zeros(p.shape, dtype=grad_dtype, device=p.device[0]).shard_(p.device, axis=p.uop.axis).contiguous()
else:
p.grad = Tensor.zeros(p.shape, dtype=grad_dtype, device=p.device).contiguous()
p.grad = p.zeros_like(dtype=grad_dtype).contiguous()
grads = [p.grad for p in optim.params]
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)

View File

@@ -222,14 +222,19 @@ class FlatTransformer:
for v in get_parameters(self): v.shard_(device, axis=None)
else:
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
self.wqkv.shard_(device, axis=1).realize() # (n_layers, out, dim) shard out
self.wo.shard_(device, axis=2).realize() # (n_layers, dim, in) shard in
def _shard_fp8(name:str, axis:int):
getattr(self, name).shard_(device, axis=axis)
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False)
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False)
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in
if SPLIT_W13:
self.w1.shard_(device, axis=1).realize()
self.w3.shard_(device, axis=1).realize()
_shard_fp8("w1", 1)
_shard_fp8("w3", 1)
else:
self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out
self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in
self.attention_norm.shard_(device, axis=None).realize()
self.ffn_norm.shard_(device, axis=None).realize()
self.norm.weight.shard_(device, axis=None).realize()
@@ -240,10 +245,6 @@ class FlatTransformer:
for name in amax_dict:
for i in range(len(amax_dict[name])):
amax_dict[name][i] = amax_dict[name][i].to(device).contiguous().is_param_(False)
for name in self._fp8_inv_scale:
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False)
for name in self._fp8_next_inv_scale:
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False)
def __call__(self, tokens:Tensor, save:bool=True):
h = self.tok_embeddings(tokens)
@@ -325,11 +326,7 @@ if __name__ == "__main__":
# preallocate all the grad buffers and zero them out
grad_dtype = lambda x: dtypes.bfloat16 if x.dtype in dtypes.fp8s else x.dtype
def _make_grad(x):
if isinstance(x.device, tuple) and x.uop.axis is not None:
return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device[0]).shard_(x.device, axis=x.uop.axis).contiguous()
return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device).contiguous()
grads = {x:_make_grad(x) for x in state.values() if x.is_param}
grads = {x:x.zeros_like(dtype=grad_dtype(x)).contiguous() for x in state.values() if x.is_param}
fp8_amax = [t for ts in model._fp8_amax.values() for t in ts]
fp8_grad_amax = [t for ts in model._fp8_grad_amax.values() for t in ts]

View File

@@ -240,9 +240,9 @@ class TestTorchBackend(unittest.TestCase):
np.testing.assert_equal(result.cpu().numpy(), [3., 3., 2.])
def test_mnist_index(self):
# from tinygrad.nn.datasets import mnist
X_train, Y_train = Tensor.randint(60000, 1, 28, 28, dtype='uchar').realize(), Tensor.randint(60000, dtype='uchar').realize()
GlobalCounters.reset()
from tinygrad.nn.datasets import mnist
X_train, Y_train, _, _ = mnist()
X_train = torch.tensor(X_train.float().numpy(), device=device)
Y_train = torch.tensor(Y_train.cast('int64').numpy(), device=device)
samples = torch.randint(0, X_train.shape[0], (32,))

View File

@@ -125,8 +125,8 @@ class TestIndexing(unittest.TestCase):
def test_index_mnist(self, noopt=1, op_limit=512*784*13, split_reduceop=0):
# WEBGPU generates more ops due to bitpacking of < 4-byte dtypes
if Device.DEFAULT == "WEBGPU": op_limit *= 15
from tinygrad.nn.datasets import mnist
X_train, Y_train, _, _ = mnist()
# from tinygrad.nn.datasets import mnist
X_train, Y_train = Tensor.randint(DSET, 1, 28, 28, dtype='uchar').realize(), Tensor.randint(DSET, dtype='uchar').realize()
with Context(NOOPT=noopt, SPLIT_REDUCEOP=split_reduceop):
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]).realize()
GlobalCounters.reset()

View File

@@ -1059,9 +1059,9 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(b.tolist(), [False, False])
def test_mnist_val(self):
from tinygrad.nn.datasets import mnist
# from tinygrad.nn.datasets import mnist
import torch
_, Y_train, _, _ = mnist()
Y_train = Tensor.randint(60000, dtype='uchar').realize()
samples = Tensor.randint(BS:=getenv("BS", 512), high=cast(int,Y_train.shape[-1])).realize()
yt = Tensor.randn(BS, 10).realize()
loss = yt.sparse_categorical_crossentropy(Y_train[samples])

View File

@@ -1000,7 +1000,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
if align_corners: return Tensor.linspace(-1, 1, steps, device=theta.device)
return Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
base_grid = Tensor.stack(*reversed(grids), grids[0].const_like(1), dim=-1)
base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)

View File

@@ -41,41 +41,42 @@ def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_
class PythonProgram:
def __init__(self, name:str, lib:bytes, **kwargs):
self.uops: list[tuple[Ops, DType, list[int], Any]] = pickle.loads(lib)
self.uops: list[UOp] = pickle.loads(lib)
self.uop_to_index: dict[UOp, int] = {u:i for i,u in enumerate(self.uops)}
self.loop_ends: dict[UOp, int] = {u.src[1]:i for i, u in enumerate(self.uops) if u.op == Ops.END}
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
st = time.perf_counter()
warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
warp_size = len(warp)
void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE}
loop_ends: dict[int, int] = {srcs[1]:i for i, (uop, _, srcs, _) in enumerate(self.uops) if uop == Ops.END}
for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
values: dict[int, Any] = {}
values: dict[UOp, Any] = {}
pbufs: list[memoryview] = list(bufs)
pvals: list[int] = list(vals)
exec_masks = [[True] * warp_size]
i = 0
while i < len(self.uops):
uop, dtype, srcs, arg = self.uops[i]
src_values = [values[v] for v in srcs if self.uops[v][0] not in void_ops]
src_dtypes = [self.uops[v][1] for v in srcs if self.uops[v][0] not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, src_values, src_dtypes)
if uop is Ops.END:
i = srcs[1]
u = self.uops[i]
src_values = [values[v] for v in u.src if v.op not in void_ops]
src_dtypes = [v.dtype for v in u.src if v.op not in void_ops]
if getenv("TRACE"): print(i, u.op, u.dtype, u.arg, src_values, src_dtypes)
if u.op is Ops.END:
i = self.uop_to_index[u.src[1]]
continue
if uop is Ops.IF:
if u.op is Ops.IF:
exec_masks.append([x and y for x,y in zip(exec_masks[-1], src_values[0])])
i += 1
continue
if uop is Ops.ENDIF:
if u.op is Ops.ENDIF:
exec_masks.pop()
i += 1
continue
if uop in (Ops.BARRIER, Ops.SINK, Ops.NOOP, Ops.GROUP):
if u.op in (Ops.BARRIER, Ops.SINK, Ops.NOOP, Ops.GROUP):
# in the python emulator, the warp is always in sync
i += 1
continue
assert dtype is not None, f"{uop} is missing a dtype"
if uop is Ops.STORE:
assert u.dtype is not None, f"{u.op} is missing a dtype"
if u.op is Ops.STORE:
assert len(src_values) == 2, f"STORE must be lowered to 2 srcs, got {len(src_values)}"
store_gate = exec_masks[-1]
for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
@@ -83,25 +84,25 @@ class PythonProgram:
if g: _store(m, o+j, v, src_dtypes[1].scalar())
i += 1
continue
if uop is Ops.AFTER: values[i] = src_values[0]
elif uop in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
assert isinstance(dtype, PtrDType), dtype
storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported")
if u.op is Ops.AFTER: values[u] = src_values[0]
elif u.op in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
assert isinstance(u.dtype, PtrDType), u.dtype
storage_fmt = storage_fmt_for_dtype(u.dtype.base.scalar())
if storage_fmt is None: raise RuntimeError(f"dtype={u.dtype} is not supported")
if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
if uop is Ops.DEFINE_REG:
if u.op is Ops.DEFINE_REG:
# REGs are per thread
values[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
values[u] = [memoryview(bytearray(u.dtype.size*u.dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
else:
buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.PARAM else pbufs.pop(0)
values[i] = [buf.cast(storage_fmt)] * warp_size
elif uop is Ops.DEFINE_VAR:
values[i] = [pvals.pop(0)] * warp_size
elif uop is Ops.SPECIAL:
if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
elif uop is Ops.CONST: values[i] = [arg] * warp_size
elif uop is Ops.INDEX:
buf = memoryview(bytearray(u.dtype.size*u.dtype.itemsize)) if u.op is not Ops.PARAM else pbufs.pop(0)
values[u] = [buf.cast(storage_fmt)] * warp_size
elif u.op is Ops.DEFINE_VAR:
values[u] = [pvals.pop(0)] * warp_size
elif u.op is Ops.SPECIAL:
if u.arg[0] == 'g': values[u] = [idxs[2-int(u.arg[-1])]] * warp_size
elif u.arg[0] == 'l': values[u] = [x[2-int(u.arg[-1])] for x in warp]
elif u.op is Ops.CONST: values[u] = [u.arg] * warp_size
elif u.op is Ops.INDEX:
ret:list = []
if isinstance(src_dtypes[0], ImageDType):
assert len(src_values) == 3, "image index must be 3 srcs"
@@ -111,33 +112,33 @@ class PythonProgram:
else:
assert len(src_values) == 2, "non-image index must be 2 srcs"
for m,o in zip(*src_values): ret.append((m,o))
values[i] = ret
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
values[i] = src_values[0]
elif uop is Ops.RANGE:
if i not in values: values[i] = [0] * warp_size
values[u] = ret
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType):
values[u] = src_values[0]
elif u.op is Ops.RANGE:
if u not in values: values[u] = [0] * warp_size
else:
for j in range(len(values[i])):
values[i][j] += 1
if values[i][0] == src_values[0][0]:
del values[i]
i = loop_ends[i] + 1
for j in range(len(values[u])):
values[u][j] += 1
if values[u][0] == src_values[0][0]:
del values[u]
i = self.loop_ends[u] + 1
continue
elif uop is Ops.STACK: values[i] = src_values
elif uop is Ops.BITCAST: values[i] = [bitcast(x, src_dtypes[0], dtype) for x in src_values[0]]
elif uop is Ops.CAST:
values[i] = [truncate.get(dtype, lambda dt: dt)(dtype.const(x)) for x in src_values[0]]
elif uop is Ops.LOAD:
if dtype.count > 1:
values[i] = [load([src_values[i][j] if i != 0 and src_dtypes[i].count > 1 else src_values[i] \
for i in range(len(src_values))], j, dtype.scalar()) for j in range(dtype.count)]
elif u.op is Ops.STACK: values[u] = src_values
elif u.op is Ops.BITCAST: values[u] = [bitcast(x, src_dtypes[0], u.dtype) for x in src_values[0]]
elif u.op is Ops.CAST:
values[u] = [truncate.get(u.dtype, lambda dt: dt)(u.dtype.const(x)) for x in src_values[0]]
elif u.op is Ops.LOAD:
if u.dtype.count > 1:
values[u] = [load([src_values[k][j] if k != 0 and src_dtypes[k].count > 1 else src_values[k] \
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(u.dtype.count)]
else:
values[i] = load(src_values, 0, dtype)
elif uop is Ops.GEP: values[i] = src_values[0][get_single_element(arg)]
elif uop is Ops.WMMA:
first_src_dtype = self.uops[srcs[0]][1]
values[u] = load(src_values, 0, u.dtype)
elif u.op is Ops.GEP: values[u] = src_values[0][get_single_element(u.arg)]
elif u.op is Ops.WMMA:
first_src_dtype = u.src[0].dtype
assert isinstance(first_src_dtype, DType) # mypy
dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
dims, dtype_in, device, threads = u.arg[1], first_src_dtype.scalar(), u.arg[4], u.arg[5]
wmma_helper = functools.partial(generic_wmma_helper, src_values, warp_size)
# TODO: refactor these to a shared TensorCoreLayout
if device == "METAL":
@@ -145,17 +146,17 @@ class PythonProgram:
def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
# (i, j), C, D (2 elements on 32 threads): row major same as A/B
def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
values[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
values[u] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
elif device == "AMD" and threads == 64:
def a_elem(x, k, row, goff): return x[k%(dims[2]//4)][goff + (k//(dims[2]//4))*16 + row]
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
values[i] = wmma_helper(64, dims[2], len(src_values[0]), len(src_values[1]), len(src_values[2]), a_elem, b_elem, c_map)
values[u] = wmma_helper(64, dims[2], len(src_values[0]), len(src_values[1]), len(src_values[2]), a_elem, b_elem, c_map)
elif device == "AMD" and len(src_values[0]) == 8: # RDNA4
def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem)
values[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
values[u] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
elif device == "AMD":
# A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
def a_elem(x, k, row, goff):
@@ -164,7 +165,7 @@ class PythonProgram:
# B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
values[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
values[u] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
elif device == "CUDA":
# (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8
def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)
@@ -172,24 +173,24 @@ class PythonProgram:
if dims == (8,16,16):
def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
values[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
values[u] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
elif dims == (8,16,32):
def a_elem(x, k, row, goff): return x[k%4 + (row//8)*4 + (k//16)*8][goff + (k//4)%4 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4]
values[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
values[u] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
elif dims == (8,16,8) and dtype_in == dtypes.half:
def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
values[u] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
elif dims == (8,16,8) and dtype_in == dtypes.float:
def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
values[u] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
else: raise NotImplementedError(f"unimplemented tensor core {u.arg}")
elif device == "INTEL":
# A (16 elements on 8 threads)
def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2]
@@ -197,17 +198,17 @@ class PythonProgram:
def b_elem(x, col, k, goff): return x[k][goff+col]
# C, D (8 elements on 8 threads)
def c_map(lane, elem): return (lane, elem)
values[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
values[u] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
elif device == "CPU":
def elem(x, col, row, _): return x[col+row][0] # k is always 0
def c_map(lane, elem): return (elem%16, elem//16)
values[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop in GroupOp.ALU:
assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {uop}"
assert all_same([dtype] + src_dtypes) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}"
values[i] = [exec_alu(uop, dtype, p) for p in zip(*src_values)]
assert i in values, (uop, dtype, srcs, arg)
values[u] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {u.arg}")
elif u.op in GroupOp.ALU:
assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {u.op}"
assert all_same([u.dtype] + src_dtypes) or u.op in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {u.op}"
values[u] = [exec_alu(u.op, u.dtype, p) for p in zip(*src_values)]
assert u in values, u
i += 1
return time.perf_counter() - st
@@ -234,10 +235,7 @@ class PythonRenderer(Renderer):
elif IMAGE and not target.arch: self.target = replace(target, arch="IMAGE_PITCH_ALIGNMENT=1")
else: self.target = target
def render(self, uops:list[UOp]) -> str:
# the value of SPECIAL comes from local/global_size, not form its source
lops = [(u.op, u.dtype, [uops.index(v) for v in u.src if u.op is not Ops.SPECIAL], u.arg) for u in uops]
return base64.b64encode(pickle.dumps(lops)).decode()
def render(self, uops:list[UOp]) -> str: return base64.b64encode(pickle.dumps(uops)).decode()
def supported_dtypes(self): return {d for d in super().supported_dtypes() if d != dtypes.half or sys.version_info >= (3, 12)}