mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Merge branch 'master' into shrink_in_render
This commit is contained in:
97
.github/workflows/test.yml
vendored
97
.github/workflows/test.yml
vendored
@@ -404,8 +404,6 @@ jobs:
|
||||
- name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=18 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||
- name: Test openpilot CL compile fp16
|
||||
run: FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||
- name: Test openpilot CL compile fp32 (test correctness)
|
||||
run: |
|
||||
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
|
||||
@@ -558,8 +556,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: dsp-minimal
|
||||
deps: testing_unit
|
||||
pydeps: "onnx==1.18.0 onnxruntime ml_dtypes"
|
||||
deps: testing
|
||||
llvm: "true"
|
||||
qemu: "true"
|
||||
- name: Set MOCKDSP env
|
||||
@@ -817,74 +814,42 @@ jobs:
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
osxwebgpu:
|
||||
name: MacOS (WebGPU)
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 10
|
||||
testmacos:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
dev:
|
||||
- 'CPU:CLANG'
|
||||
- 'CPU:LLVM'
|
||||
- 'CPU:LVP'
|
||||
- 'METAL'
|
||||
- 'WEBGPU'
|
||||
|
||||
name: MacOS (DEV=${{ matrix.dev }})
|
||||
runs-on: macos-15
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: osx-webgpu
|
||||
deps: testing
|
||||
webgpu: 'true'
|
||||
- name: Build WEBGPU Efficientnet
|
||||
run: DEV=WEBGPU WEBGPU_BACKEND="WGPUBackendType_Metal" python3 -m examples.compile_efficientnet
|
||||
- name: Run selected webgpu tests
|
||||
run: DEV=WEBGPU WEBGPU_BACKEND="WGPUBackendType_Metal" python3 -m pytest -n=auto test/backend --durations=20
|
||||
#- name: Clean npm cache
|
||||
# run: npm cache clean --force
|
||||
#- name: Install Puppeteer
|
||||
# run: npm install puppeteer
|
||||
# this is also flaky
|
||||
#- name: Run WEBGPU Efficientnet
|
||||
# run: node test/web/test_webgpu.js
|
||||
# this is flaky
|
||||
#- name: Run VIZ tests as external package
|
||||
# run: |
|
||||
# mkdir $GITHUB_WORKSPACE/test_dir
|
||||
# cd $GITHUB_WORKSPACE/test_dir
|
||||
# python -m venv venv
|
||||
# source venv/bin/activate
|
||||
# pip install $GITHUB_WORKSPACE
|
||||
# cp $GITHUB_WORKSPACE/test/web/test_viz.js .
|
||||
# node test_viz.js
|
||||
- name: Test ONNX Runner (WEBGPU)
|
||||
run: DEV=WEBGPU python3 test/external/external_test_onnx_runner.py
|
||||
|
||||
osxtests:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: [metal, llvm, cpu, lvp]
|
||||
name: MacOS (${{ matrix.backend }})
|
||||
runs-on: macos-15
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: macos-${{ matrix.backend }}-minimal
|
||||
deps: testing_unit
|
||||
llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'lvp' }}
|
||||
mesa: ${{ matrix.backend == 'lvp' && 'cpu' }}
|
||||
- name: Set env
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'metal' && 'DEV=METAL' || matrix.backend == 'lvp' && 'DEV=CPU:LVP' }}" >> $GITHUB_ENV
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
run: |
|
||||
python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU','LVP':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT"
|
||||
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
|
||||
- name: Run pytest (${{ matrix.backend }})
|
||||
run: python3 -m pytest -n=auto test/backend --durations=20
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
- name: Run macOS-specific unit test
|
||||
if: matrix.backend == 'llvm'
|
||||
run: python3 -m pytest test/unit/test_disk_tensor.py::TestDiskTensor::test_copy_to_cpu_not_truncated test/unit/test_cpu.py
|
||||
key: macos-${{ matrix.dev }}
|
||||
deps: testing_unit
|
||||
python-version: '3.12'
|
||||
llvm: ${{ contains(matrix.dev, 'LLVM') || contains(matrix.dev, 'LVP') }}
|
||||
mesa: ${{ contains(matrix.dev, 'LVP') && 'cpu' || 'false' }}
|
||||
webgpu: ${{ matrix.dev == 'WEBGPU' }}
|
||||
- name: Set env
|
||||
run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANG' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
run: |
|
||||
python -c "from tinygrad import Device; from tinygrad.helpers import Target; assert Device.DEFAULT == Target.parse('${{ matrix.dev }}').device"
|
||||
DEBUG=4 python test/test_tiny.py TestTiny.test_plus
|
||||
- name: Run backend tests
|
||||
run: python -m pytest -n=auto test/backend --durations=20
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
# ****** Windows Tests ******
|
||||
|
||||
|
||||
@@ -1419,10 +1419,7 @@ def train_llama3():
|
||||
|
||||
for p in optim.params:
|
||||
grad_dtype = dtypes.bfloat16 if p.dtype == FP8_DTYPE else p.dtype
|
||||
if isinstance(p.device, tuple) and p.uop.axis is not None:
|
||||
p.grad = Tensor.zeros(p.shape, dtype=grad_dtype, device=p.device[0]).shard_(p.device, axis=p.uop.axis).contiguous()
|
||||
else:
|
||||
p.grad = Tensor.zeros(p.shape, dtype=grad_dtype, device=p.device).contiguous()
|
||||
p.grad = p.zeros_like(dtype=grad_dtype).contiguous()
|
||||
grads = [p.grad for p in optim.params]
|
||||
|
||||
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
|
||||
|
||||
@@ -222,14 +222,19 @@ class FlatTransformer:
|
||||
for v in get_parameters(self): v.shard_(device, axis=None)
|
||||
else:
|
||||
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
|
||||
self.wqkv.shard_(device, axis=1).realize() # (n_layers, out, dim) shard out
|
||||
self.wo.shard_(device, axis=2).realize() # (n_layers, dim, in) shard in
|
||||
def _shard_fp8(name:str, axis:int):
|
||||
getattr(self, name).shard_(device, axis=axis)
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False)
|
||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False)
|
||||
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
|
||||
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
|
||||
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in
|
||||
if SPLIT_W13:
|
||||
self.w1.shard_(device, axis=1).realize()
|
||||
self.w3.shard_(device, axis=1).realize()
|
||||
_shard_fp8("w1", 1)
|
||||
_shard_fp8("w3", 1)
|
||||
else:
|
||||
self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out
|
||||
self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in
|
||||
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
|
||||
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in
|
||||
self.attention_norm.shard_(device, axis=None).realize()
|
||||
self.ffn_norm.shard_(device, axis=None).realize()
|
||||
self.norm.weight.shard_(device, axis=None).realize()
|
||||
@@ -240,10 +245,6 @@ class FlatTransformer:
|
||||
for name in amax_dict:
|
||||
for i in range(len(amax_dict[name])):
|
||||
amax_dict[name][i] = amax_dict[name][i].to(device).contiguous().is_param_(False)
|
||||
for name in self._fp8_inv_scale:
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False)
|
||||
for name in self._fp8_next_inv_scale:
|
||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False)
|
||||
|
||||
def __call__(self, tokens:Tensor, save:bool=True):
|
||||
h = self.tok_embeddings(tokens)
|
||||
@@ -325,11 +326,7 @@ if __name__ == "__main__":
|
||||
|
||||
# preallocate all the grad buffers and zero them out
|
||||
grad_dtype = lambda x: dtypes.bfloat16 if x.dtype in dtypes.fp8s else x.dtype
|
||||
def _make_grad(x):
|
||||
if isinstance(x.device, tuple) and x.uop.axis is not None:
|
||||
return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device[0]).shard_(x.device, axis=x.uop.axis).contiguous()
|
||||
return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device).contiguous()
|
||||
grads = {x:_make_grad(x) for x in state.values() if x.is_param}
|
||||
grads = {x:x.zeros_like(dtype=grad_dtype(x)).contiguous() for x in state.values() if x.is_param}
|
||||
|
||||
fp8_amax = [t for ts in model._fp8_amax.values() for t in ts]
|
||||
fp8_grad_amax = [t for ts in model._fp8_grad_amax.values() for t in ts]
|
||||
|
||||
@@ -240,9 +240,9 @@ class TestTorchBackend(unittest.TestCase):
|
||||
np.testing.assert_equal(result.cpu().numpy(), [3., 3., 2.])
|
||||
|
||||
def test_mnist_index(self):
|
||||
# from tinygrad.nn.datasets import mnist
|
||||
X_train, Y_train = Tensor.randint(60000, 1, 28, 28, dtype='uchar').realize(), Tensor.randint(60000, dtype='uchar').realize()
|
||||
GlobalCounters.reset()
|
||||
from tinygrad.nn.datasets import mnist
|
||||
X_train, Y_train, _, _ = mnist()
|
||||
X_train = torch.tensor(X_train.float().numpy(), device=device)
|
||||
Y_train = torch.tensor(Y_train.cast('int64').numpy(), device=device)
|
||||
samples = torch.randint(0, X_train.shape[0], (32,))
|
||||
|
||||
@@ -125,8 +125,8 @@ class TestIndexing(unittest.TestCase):
|
||||
def test_index_mnist(self, noopt=1, op_limit=512*784*13, split_reduceop=0):
|
||||
# WEBGPU generates more ops due to bitpacking of < 4-byte dtypes
|
||||
if Device.DEFAULT == "WEBGPU": op_limit *= 15
|
||||
from tinygrad.nn.datasets import mnist
|
||||
X_train, Y_train, _, _ = mnist()
|
||||
# from tinygrad.nn.datasets import mnist
|
||||
X_train, Y_train = Tensor.randint(DSET, 1, 28, 28, dtype='uchar').realize(), Tensor.randint(DSET, dtype='uchar').realize()
|
||||
with Context(NOOPT=noopt, SPLIT_REDUCEOP=split_reduceop):
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]).realize()
|
||||
GlobalCounters.reset()
|
||||
|
||||
@@ -1059,9 +1059,9 @@ class TestSchedule(unittest.TestCase):
|
||||
self.assertEqual(b.tolist(), [False, False])
|
||||
|
||||
def test_mnist_val(self):
|
||||
from tinygrad.nn.datasets import mnist
|
||||
# from tinygrad.nn.datasets import mnist
|
||||
import torch
|
||||
_, Y_train, _, _ = mnist()
|
||||
Y_train = Tensor.randint(60000, dtype='uchar').realize()
|
||||
samples = Tensor.randint(BS:=getenv("BS", 512), high=cast(int,Y_train.shape[-1])).realize()
|
||||
yt = Tensor.randn(BS, 10).realize()
|
||||
loss = yt.sparse_categorical_crossentropy(Y_train[samples])
|
||||
|
||||
@@ -1000,7 +1000,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
if align_corners: return Tensor.linspace(-1, 1, steps, device=theta.device)
|
||||
return Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
|
||||
grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
|
||||
base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
|
||||
base_grid = Tensor.stack(*reversed(grids), grids[0].const_like(1), dim=-1)
|
||||
base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
|
||||
return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
|
||||
|
||||
|
||||
@@ -41,41 +41,42 @@ def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_
|
||||
|
||||
class PythonProgram:
|
||||
def __init__(self, name:str, lib:bytes, **kwargs):
|
||||
self.uops: list[tuple[Ops, DType, list[int], Any]] = pickle.loads(lib)
|
||||
self.uops: list[UOp] = pickle.loads(lib)
|
||||
self.uop_to_index: dict[UOp, int] = {u:i for i,u in enumerate(self.uops)}
|
||||
self.loop_ends: dict[UOp, int] = {u.src[1]:i for i, u in enumerate(self.uops) if u.op == Ops.END}
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
|
||||
st = time.perf_counter()
|
||||
warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
|
||||
warp_size = len(warp)
|
||||
void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE}
|
||||
loop_ends: dict[int, int] = {srcs[1]:i for i, (uop, _, srcs, _) in enumerate(self.uops) if uop == Ops.END}
|
||||
for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
|
||||
values: dict[int, Any] = {}
|
||||
values: dict[UOp, Any] = {}
|
||||
pbufs: list[memoryview] = list(bufs)
|
||||
pvals: list[int] = list(vals)
|
||||
exec_masks = [[True] * warp_size]
|
||||
i = 0
|
||||
while i < len(self.uops):
|
||||
uop, dtype, srcs, arg = self.uops[i]
|
||||
src_values = [values[v] for v in srcs if self.uops[v][0] not in void_ops]
|
||||
src_dtypes = [self.uops[v][1] for v in srcs if self.uops[v][0] not in void_ops]
|
||||
if getenv("TRACE"): print(i, uop, dtype, arg, src_values, src_dtypes)
|
||||
if uop is Ops.END:
|
||||
i = srcs[1]
|
||||
u = self.uops[i]
|
||||
src_values = [values[v] for v in u.src if v.op not in void_ops]
|
||||
src_dtypes = [v.dtype for v in u.src if v.op not in void_ops]
|
||||
if getenv("TRACE"): print(i, u.op, u.dtype, u.arg, src_values, src_dtypes)
|
||||
if u.op is Ops.END:
|
||||
i = self.uop_to_index[u.src[1]]
|
||||
continue
|
||||
if uop is Ops.IF:
|
||||
if u.op is Ops.IF:
|
||||
exec_masks.append([x and y for x,y in zip(exec_masks[-1], src_values[0])])
|
||||
i += 1
|
||||
continue
|
||||
if uop is Ops.ENDIF:
|
||||
if u.op is Ops.ENDIF:
|
||||
exec_masks.pop()
|
||||
i += 1
|
||||
continue
|
||||
if uop in (Ops.BARRIER, Ops.SINK, Ops.NOOP, Ops.GROUP):
|
||||
if u.op in (Ops.BARRIER, Ops.SINK, Ops.NOOP, Ops.GROUP):
|
||||
# in the python emulator, the warp is always in sync
|
||||
i += 1
|
||||
continue
|
||||
assert dtype is not None, f"{uop} is missing a dtype"
|
||||
if uop is Ops.STORE:
|
||||
assert u.dtype is not None, f"{u.op} is missing a dtype"
|
||||
if u.op is Ops.STORE:
|
||||
assert len(src_values) == 2, f"STORE must be lowered to 2 srcs, got {len(src_values)}"
|
||||
store_gate = exec_masks[-1]
|
||||
for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
|
||||
@@ -83,25 +84,25 @@ class PythonProgram:
|
||||
if g: _store(m, o+j, v, src_dtypes[1].scalar())
|
||||
i += 1
|
||||
continue
|
||||
if uop is Ops.AFTER: values[i] = src_values[0]
|
||||
elif uop in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
|
||||
assert isinstance(dtype, PtrDType), dtype
|
||||
storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
|
||||
if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported")
|
||||
if u.op is Ops.AFTER: values[u] = src_values[0]
|
||||
elif u.op in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
|
||||
assert isinstance(u.dtype, PtrDType), u.dtype
|
||||
storage_fmt = storage_fmt_for_dtype(u.dtype.base.scalar())
|
||||
if storage_fmt is None: raise RuntimeError(f"dtype={u.dtype} is not supported")
|
||||
if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
|
||||
if uop is Ops.DEFINE_REG:
|
||||
if u.op is Ops.DEFINE_REG:
|
||||
# REGs are per thread
|
||||
values[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
|
||||
values[u] = [memoryview(bytearray(u.dtype.size*u.dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
|
||||
else:
|
||||
buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.PARAM else pbufs.pop(0)
|
||||
values[i] = [buf.cast(storage_fmt)] * warp_size
|
||||
elif uop is Ops.DEFINE_VAR:
|
||||
values[i] = [pvals.pop(0)] * warp_size
|
||||
elif uop is Ops.SPECIAL:
|
||||
if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
|
||||
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
|
||||
elif uop is Ops.CONST: values[i] = [arg] * warp_size
|
||||
elif uop is Ops.INDEX:
|
||||
buf = memoryview(bytearray(u.dtype.size*u.dtype.itemsize)) if u.op is not Ops.PARAM else pbufs.pop(0)
|
||||
values[u] = [buf.cast(storage_fmt)] * warp_size
|
||||
elif u.op is Ops.DEFINE_VAR:
|
||||
values[u] = [pvals.pop(0)] * warp_size
|
||||
elif u.op is Ops.SPECIAL:
|
||||
if u.arg[0] == 'g': values[u] = [idxs[2-int(u.arg[-1])]] * warp_size
|
||||
elif u.arg[0] == 'l': values[u] = [x[2-int(u.arg[-1])] for x in warp]
|
||||
elif u.op is Ops.CONST: values[u] = [u.arg] * warp_size
|
||||
elif u.op is Ops.INDEX:
|
||||
ret:list = []
|
||||
if isinstance(src_dtypes[0], ImageDType):
|
||||
assert len(src_values) == 3, "image index must be 3 srcs"
|
||||
@@ -111,33 +112,33 @@ class PythonProgram:
|
||||
else:
|
||||
assert len(src_values) == 2, "non-image index must be 2 srcs"
|
||||
for m,o in zip(*src_values): ret.append((m,o))
|
||||
values[i] = ret
|
||||
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
|
||||
values[i] = src_values[0]
|
||||
elif uop is Ops.RANGE:
|
||||
if i not in values: values[i] = [0] * warp_size
|
||||
values[u] = ret
|
||||
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType):
|
||||
values[u] = src_values[0]
|
||||
elif u.op is Ops.RANGE:
|
||||
if u not in values: values[u] = [0] * warp_size
|
||||
else:
|
||||
for j in range(len(values[i])):
|
||||
values[i][j] += 1
|
||||
if values[i][0] == src_values[0][0]:
|
||||
del values[i]
|
||||
i = loop_ends[i] + 1
|
||||
for j in range(len(values[u])):
|
||||
values[u][j] += 1
|
||||
if values[u][0] == src_values[0][0]:
|
||||
del values[u]
|
||||
i = self.loop_ends[u] + 1
|
||||
continue
|
||||
elif uop is Ops.STACK: values[i] = src_values
|
||||
elif uop is Ops.BITCAST: values[i] = [bitcast(x, src_dtypes[0], dtype) for x in src_values[0]]
|
||||
elif uop is Ops.CAST:
|
||||
values[i] = [truncate.get(dtype, lambda dt: dt)(dtype.const(x)) for x in src_values[0]]
|
||||
elif uop is Ops.LOAD:
|
||||
if dtype.count > 1:
|
||||
values[i] = [load([src_values[i][j] if i != 0 and src_dtypes[i].count > 1 else src_values[i] \
|
||||
for i in range(len(src_values))], j, dtype.scalar()) for j in range(dtype.count)]
|
||||
elif u.op is Ops.STACK: values[u] = src_values
|
||||
elif u.op is Ops.BITCAST: values[u] = [bitcast(x, src_dtypes[0], u.dtype) for x in src_values[0]]
|
||||
elif u.op is Ops.CAST:
|
||||
values[u] = [truncate.get(u.dtype, lambda dt: dt)(u.dtype.const(x)) for x in src_values[0]]
|
||||
elif u.op is Ops.LOAD:
|
||||
if u.dtype.count > 1:
|
||||
values[u] = [load([src_values[k][j] if k != 0 and src_dtypes[k].count > 1 else src_values[k] \
|
||||
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(u.dtype.count)]
|
||||
else:
|
||||
values[i] = load(src_values, 0, dtype)
|
||||
elif uop is Ops.GEP: values[i] = src_values[0][get_single_element(arg)]
|
||||
elif uop is Ops.WMMA:
|
||||
first_src_dtype = self.uops[srcs[0]][1]
|
||||
values[u] = load(src_values, 0, u.dtype)
|
||||
elif u.op is Ops.GEP: values[u] = src_values[0][get_single_element(u.arg)]
|
||||
elif u.op is Ops.WMMA:
|
||||
first_src_dtype = u.src[0].dtype
|
||||
assert isinstance(first_src_dtype, DType) # mypy
|
||||
dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
|
||||
dims, dtype_in, device, threads = u.arg[1], first_src_dtype.scalar(), u.arg[4], u.arg[5]
|
||||
wmma_helper = functools.partial(generic_wmma_helper, src_values, warp_size)
|
||||
# TODO: refactor these to a shared TensorCoreLayout
|
||||
if device == "METAL":
|
||||
@@ -145,17 +146,17 @@ class PythonProgram:
|
||||
def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
|
||||
# (i, j), C, D (2 elements on 32 threads): row major same as A/B
|
||||
def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
|
||||
values[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
|
||||
values[u] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
|
||||
elif device == "AMD" and threads == 64:
|
||||
def a_elem(x, k, row, goff): return x[k%(dims[2]//4)][goff + (k//(dims[2]//4))*16 + row]
|
||||
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
|
||||
def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
|
||||
values[i] = wmma_helper(64, dims[2], len(src_values[0]), len(src_values[1]), len(src_values[2]), a_elem, b_elem, c_map)
|
||||
values[u] = wmma_helper(64, dims[2], len(src_values[0]), len(src_values[1]), len(src_values[2]), a_elem, b_elem, c_map)
|
||||
elif device == "AMD" and len(src_values[0]) == 8: # RDNA4
|
||||
def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
|
||||
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
|
||||
def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem)
|
||||
values[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
|
||||
values[u] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
|
||||
elif device == "AMD":
|
||||
# A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
|
||||
def a_elem(x, k, row, goff):
|
||||
@@ -164,7 +165,7 @@ class PythonProgram:
|
||||
# B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
|
||||
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
|
||||
def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
|
||||
values[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
|
||||
values[u] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
|
||||
elif device == "CUDA":
|
||||
# (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8
|
||||
def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)
|
||||
@@ -172,24 +173,24 @@ class PythonProgram:
|
||||
if dims == (8,16,16):
|
||||
def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4]
|
||||
def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
|
||||
values[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
|
||||
values[u] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
|
||||
|
||||
elif dims == (8,16,32):
|
||||
def a_elem(x, k, row, goff): return x[k%4 + (row//8)*4 + (k//16)*8][goff + (k//4)%4 + (row%8)*4]
|
||||
def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4]
|
||||
values[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
|
||||
values[u] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
|
||||
|
||||
elif dims == (8,16,8) and dtype_in == dtypes.half:
|
||||
def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
|
||||
def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
|
||||
values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
|
||||
values[u] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
|
||||
|
||||
elif dims == (8,16,8) and dtype_in == dtypes.float:
|
||||
def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4]
|
||||
def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
|
||||
values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
|
||||
values[u] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
|
||||
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {u.arg}")
|
||||
elif device == "INTEL":
|
||||
# A (16 elements on 8 threads)
|
||||
def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2]
|
||||
@@ -197,17 +198,17 @@ class PythonProgram:
|
||||
def b_elem(x, col, k, goff): return x[k][goff+col]
|
||||
# C, D (8 elements on 8 threads)
|
||||
def c_map(lane, elem): return (lane, elem)
|
||||
values[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
|
||||
values[u] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
|
||||
elif device == "CPU":
|
||||
def elem(x, col, row, _): return x[col+row][0] # k is always 0
|
||||
def c_map(lane, elem): return (elem%16, elem//16)
|
||||
values[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
|
||||
elif uop in GroupOp.ALU:
|
||||
assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {uop}"
|
||||
assert all_same([dtype] + src_dtypes) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}"
|
||||
values[i] = [exec_alu(uop, dtype, p) for p in zip(*src_values)]
|
||||
assert i in values, (uop, dtype, srcs, arg)
|
||||
values[u] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {u.arg}")
|
||||
elif u.op in GroupOp.ALU:
|
||||
assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {u.op}"
|
||||
assert all_same([u.dtype] + src_dtypes) or u.op in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {u.op}"
|
||||
values[u] = [exec_alu(u.op, u.dtype, p) for p in zip(*src_values)]
|
||||
assert u in values, u
|
||||
i += 1
|
||||
return time.perf_counter() - st
|
||||
|
||||
@@ -234,10 +235,7 @@ class PythonRenderer(Renderer):
|
||||
elif IMAGE and not target.arch: self.target = replace(target, arch="IMAGE_PITCH_ALIGNMENT=1")
|
||||
else: self.target = target
|
||||
|
||||
def render(self, uops:list[UOp]) -> str:
|
||||
# the value of SPECIAL comes from local/global_size, not form its source
|
||||
lops = [(u.op, u.dtype, [uops.index(v) for v in u.src if u.op is not Ops.SPECIAL], u.arg) for u in uops]
|
||||
return base64.b64encode(pickle.dumps(lops)).decode()
|
||||
def render(self, uops:list[UOp]) -> str: return base64.b64encode(pickle.dumps(uops)).decode()
|
||||
|
||||
def supported_dtypes(self): return {d for d in super().supported_dtypes() if d != dtypes.half or sys.version_info >= (3, 12)}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user