mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-07-02 12:02:09 +08:00
September 27th, 2025 Update
This commit is contained in:
@@ -4,25 +4,35 @@
|
||||
# this is the (living) definition of uops
|
||||
from typing import Any, TYPE_CHECKING
|
||||
import pickle, base64, itertools, time, struct, sys
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16
|
||||
from tinygrad.helpers import all_same, getenv, flatten, get_single_element
|
||||
from tinygrad.device import Compiled, Compiler, Allocator
|
||||
from tinygrad.opt import tc
|
||||
from tinygrad.uop.ops import exec_alu, Ops, UOp, GroupOp
|
||||
from tinygrad.codegen.opt import tc
|
||||
from tinygrad.uop.ops import exec_alu, python_alu, Ops, UOp, GroupOp
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
def _load(m, i):
|
||||
def storage_fmt_for_dtype(dtype: DType):
  """Return the struct/memoryview format code used to store elements of `dtype`.

  bfloat16 is stored under 'H' (an unsigned 16-bit integer code); every other
  dtype uses its own `fmt` attribute, whatever that is.
  """
  if dtype == dtypes.bfloat16:
    return 'H'
  return dtype.fmt
|
||||
|
||||
def to_storage_scalar(x, dtype: DType):
  """Convert scalar `x` into the value that is written to a buffer of `dtype`.

  For bfloat16 the value is passed through float_to_bf16, packed as a 32-bit
  float, reread as a 32-bit unsigned integer, and reduced to its upper 16 bits.
  For any other dtype `x` is returned unchanged.
  """
  if dtype == dtypes.bfloat16:
    # reinterpret the float32 encoding as an integer so the top half can be sliced out
    f32_bits = struct.unpack('I', struct.pack('f', float_to_bf16(x)))[0]
    return (f32_bits >> 16) & 0xFFFF
  return x
|
||||
|
||||
def from_storage_scalar(x, dtype: DType):
  """Convert a value read from a buffer of `dtype` back into a Python scalar.

  For bfloat16 the low 16 bits of `x` are moved into the upper half of a
  32-bit word, which is then reinterpreted as a 32-bit float. For any other
  dtype `x` is returned unchanged.
  """
  if dtype == dtypes.bfloat16:
    # the stored 16 bits become the high half of the float32 bit pattern; the low half is zero
    widened = (x & 0xFFFF) << 16
    return struct.unpack('f', struct.pack('I', widened))[0]
  return x
|
||||
|
||||
def _load(m, i, dtype: DType):
|
||||
if i is None: return 0.0
|
||||
if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
|
||||
return m[i]
|
||||
return from_storage_scalar(m[i], dtype)
|
||||
|
||||
def load(inp, j=0):
|
||||
if len(inp) == 2: return [_load(m, x+j if x is not None else None) if gate else default for (m,x,gate),default in zip(*inp)]
|
||||
return [_load(m, x+j if x is not None else None) for m,x,_ in inp[0]]
|
||||
def load(inp, j, dtype: DType):
|
||||
if len(inp) == 2: return [_load(m, x+j if x is not None else None, dtype) if gate else default for (m,x,gate),default in zip(*inp)]
|
||||
return [_load(m, x+j if x is not None else None, dtype) for m,x,_ in inp[0]]
|
||||
|
||||
def _store(m, i, v):
|
||||
def _store(m, i, v, dtype: DType):
|
||||
if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
|
||||
m[i] = v
|
||||
m[i] = to_storage_scalar(v, dtype)
|
||||
|
||||
class PythonProgram:
|
||||
def __init__(self, name:str, lib:bytes):
|
||||
@@ -57,18 +67,20 @@ class PythonProgram:
|
||||
if uop is Ops.STORE:
|
||||
for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
|
||||
for (m,o,g),v in zip(inp[0], val):
|
||||
if g: _store(m, o+j, v)
|
||||
if g: _store(m, o+j, v, dtp[1].scalar())
|
||||
i += 1
|
||||
continue
|
||||
if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
|
||||
assert dtype.fmt is not None and isinstance(dtype, PtrDType)
|
||||
if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e"
|
||||
assert isinstance(dtype, PtrDType), dtype
|
||||
storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
|
||||
if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported")
|
||||
if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
|
||||
if uop is Ops.DEFINE_REG:
|
||||
# REGs are per thread
|
||||
ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(dtype.fmt) for _ in range(warp_size)]
|
||||
ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
|
||||
else:
|
||||
buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.DEFINE_GLOBAL else pbufs.pop(0)
|
||||
ul[i] = [buf.cast(dtype.fmt)] * warp_size
|
||||
ul[i] = [buf.cast(storage_fmt)] * warp_size
|
||||
elif uop is Ops.DEFINE_VAR:
|
||||
ul[i] = [pvals.pop(0)] * warp_size
|
||||
elif uop is Ops.SPECIAL:
|
||||
@@ -97,16 +109,17 @@ class PythonProgram:
|
||||
continue
|
||||
elif uop is Ops.VECTORIZE: ul[i] = inp
|
||||
elif uop is Ops.BITCAST:
|
||||
assert dtp[0].fmt and dtype.fmt
|
||||
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
|
||||
ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
|
||||
packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(dtp[0].scalar()), *[to_storage_scalar(x, dtp[0].scalar()) for x in inp[0]])
|
||||
ul[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed))
|
||||
ul[i] = [from_storage_scalar(x, dtype.scalar()) for x in ul[i]]
|
||||
elif uop is Ops.CAST:
|
||||
ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
|
||||
elif uop is Ops.LOAD:
|
||||
if dtype.count > 1:
|
||||
ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
|
||||
ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j, dtype.scalar()) \
|
||||
for j in range(dtype.count)]
|
||||
else:
|
||||
ul[i] = load(inp)
|
||||
ul[i] = load(inp, 0, dtype)
|
||||
elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
|
||||
elif uop is Ops.WMMA:
|
||||
# here are the models for the WMMA instruction on the different hardware
|
||||
@@ -123,24 +136,27 @@ class PythonProgram:
|
||||
out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
|
||||
return out
|
||||
|
||||
first_src_dtype = self.uops[idp[0]][1]
|
||||
assert isinstance(first_src_dtype, DType) # mypy
|
||||
dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
|
||||
# TODO: refactor these to a shared TensorCoreLayout in kernel.py
|
||||
if arg[4] == "METAL":
|
||||
if device == "METAL":
|
||||
# A (2 elements on 32 threads): row major
|
||||
def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
|
||||
# (i, j), C, D (2 elements on 32 threads): row major same as A/B
|
||||
def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
|
||||
ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
|
||||
elif arg[4] == "AMD" and arg[5] == 64:
|
||||
elif device == "AMD" and threads == 64:
|
||||
def a_elem(x, k, row, goff): return x[k%4][goff + (k//4)*16 + row]
|
||||
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
|
||||
def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
|
||||
ul[i] = wmma_helper(64, 16, 4, 4, 4, a_elem, b_elem, c_map)
|
||||
elif arg[4] == "AMD" and len(inp[0]) == 8: # RDNA4
|
||||
elif device == "AMD" and len(inp[0]) == 8: # RDNA4
|
||||
def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
|
||||
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
|
||||
def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem)
|
||||
ul[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
|
||||
elif arg[4] == "AMD":
|
||||
elif device == "AMD":
|
||||
# A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
|
||||
def a_elem(x, k, row, goff):
|
||||
assert x[k][goff+row] == x[k][goff+row+16], "warp elements not duplicated properly across lanes"
|
||||
@@ -149,27 +165,27 @@ class PythonProgram:
|
||||
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
|
||||
def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
|
||||
ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
|
||||
elif arg[4] == "CUDA":
|
||||
elif device == "CUDA":
|
||||
# (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8
|
||||
def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)
|
||||
|
||||
if arg[1] == (8,16,16):
|
||||
if dims == (8,16,16):
|
||||
def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4]
|
||||
def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
|
||||
ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
|
||||
|
||||
elif arg[1] == (8,16,8) and arg[2] == dtypes.half:
|
||||
elif dims == (8,16,8) and dtype_in == dtypes.half:
|
||||
def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
|
||||
def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
|
||||
ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
|
||||
|
||||
elif arg[1] == (8,16,8) and arg[2] == dtypes.float:
|
||||
elif dims == (8,16,8) and dtype_in == dtypes.float:
|
||||
def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4]
|
||||
def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
|
||||
ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
|
||||
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
|
||||
elif arg[4] == "INTEL":
|
||||
elif device == "INTEL":
|
||||
# A (16 elements on 8 threads)
|
||||
def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2]
|
||||
# B (16 elements on 8 threads)
|
||||
@@ -177,14 +193,14 @@ class PythonProgram:
|
||||
# C, D (8 elements on 8 threads)
|
||||
def c_map(lane, elem): return (lane, elem)
|
||||
ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
|
||||
elif arg[4] == "CPU":
|
||||
elif device == "CPU":
|
||||
def elem(x, col, row, _): return x[col+row][0] # k is always 0
|
||||
def c_map(_, elem): return (elem%16, elem//16)
|
||||
ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
|
||||
elif uop in GroupOp.ALU:
|
||||
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}"
|
||||
assert all_same([dtype] + dtp) or uop in {Ops.CMPNE, Ops.CMPLT, Ops.WHERE}, f"dtype mismatch on {uop}"
|
||||
assert all_same([dtype] + dtp) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}"
|
||||
ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)]
|
||||
assert i in ul, (uop, dtype, idp, arg)
|
||||
i += 1
|
||||
@@ -192,6 +208,7 @@ class PythonProgram:
|
||||
|
||||
class PythonRenderer(Renderer):
|
||||
device = "PYTHON"
|
||||
code_for_op = python_alu
|
||||
def __init__(self):
|
||||
if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", tc.metal
|
||||
if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", tc.amd_rdna3
|
||||
|
||||
Reference in New Issue
Block a user