Merge branch 'master' into faster-approx-fix-cycle-graph

This commit is contained in:
hikettei
2024-07-04 08:29:23 +09:00
committed by GitHub
12 changed files with 153 additions and 159 deletions

View File

@@ -1,7 +1,6 @@
import os, sys, math, argparse, time
sys.path.append(os.getcwd())
from typing import Any, Optional, Dict
from dataclasses import dataclass, field
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import fetch
@@ -145,56 +144,52 @@ class MambaMixer:
self.out_proj = nn.Linear(self.d_inner, self.dim, bias=bias)
def __call__(self, hidden_states: Tensor, inference_params=None):
batch, seqlen, dim = hidden_states.shape
def __call__(self, hidden_states: Tensor):
batch, seqlen, _ = hidden_states.shape
conv_state, ssm_state = None, None
if inference_params is not None:
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(hidden_states[:, -1:, :], conv_state, ssm_state)
return out
if not hasattr(self, 'conv_state'):
self.conv_state = Tensor.zeros(batch, self.dim * self.expand, self.d_conv).contiguous().realize()
self.ssm_state = Tensor.zeros(batch, self.dim * self.expand, self.d_state).realize()
xz = self.in_proj.weight @ hidden_states.permute(2,0,1).reshape(hidden_states.shape[2],hidden_states.shape[1]*hidden_states.shape[0])
xz = xz.reshape(xz.shape[0],xz.shape[1]//seqlen, seqlen).permute(1,0,2)
xz = self.in_proj.weight @ hidden_states.permute(2,0,1).reshape(hidden_states.shape[2],hidden_states.shape[1]*hidden_states.shape[0])
xz = xz.reshape(xz.shape[0],xz.shape[1]//seqlen, seqlen).permute(1,0,2)
if self.in_proj.bias is not None:
xz = xz + self.in_proj.bias.reshape((-1, 1))
if self.in_proj.bias is not None:
xz = xz + self.in_proj.bias.reshape((-1, 1))
A = -self.A_log.exp()
x, z = xz.chunk(2, dim=1)
# Compute short convolution
if conv_state is not None:
conv_state.assign(x[:, :, -self.d_conv :]) # Update state (B D W)
A = -self.A_log.exp()
x, z = xz.chunk(2, dim=1)
# Compute short convolution
self.conv_state.assign(x[:, :, -self.d_conv :]) # Update state (B D W)
x = self.conv1d(x)[..., :seqlen].swish()
x_dbl = self.x_proj(x.permute(0,2,1).reshape(x.shape[0]*x.shape[2], x.shape[1]))
dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.T
dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1)
C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1)
x_dbl = self.x_proj(x.permute(0,2,1).reshape(x.shape[0]*x.shape[2], x.shape[1]))
dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.T
dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1)
C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1)
# TODO: actually implement selective_scan_fn
y = selective_scan_ref(x, dt, A, B, C, self.D, z=z, delta_bias=self.dt_proj.bias, delta_softplus=True,
return_last_state=True)
# TODO: actually implement selective_scan_fn
y = selective_scan_ref(x, dt, A, B, C, self.D, z=z, delta_bias=self.dt_proj.bias, delta_softplus=True,
return_last_state=ssm_state is not None)
if ssm_state is not None:
y, last_state = y
ssm_state.assign(last_state)
self.ssm_state.assign(last_state).realize()
y = y.permute(0,2,1)
out = self.out_proj(y)
return out
else:
return self.step(hidden_states)
y = y.permute(0,2,1)
out = self.out_proj(y)
return out
def step(self, hidden_states: Tensor, conv_state: Tensor, ssm_state: Tensor):
def step(self, hidden_states: Tensor):
assert hidden_states.shape[1] == 1, f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
x, z = xz.chunk(2, dim=-1) # (B D)
# Conv step
conv_state.assign(conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1))
x = (conv_state * self.conv1d.weight.squeeze(1)).sum(-1)
self.conv_state.assign(self.conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1).realize())
x = (self.conv_state * self.conv1d.weight.squeeze(1)).sum(-1)
if self.conv1d.bias is not None:
x = x + self.conv1d.bias
x = x.swish()
@@ -211,23 +206,13 @@ class MambaMixer:
dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
dB = Tensor.einsum("db,bn->bdn", dt, B)
ssm_state.assign(ssm_state * dA + x.unsqueeze(-1) * dB)
y = Tensor.einsum("bdn,bn->bd", ssm_state, C)
self.ssm_state.assign(self.ssm_state * dA + x.unsqueeze(-1) * dB)
y = Tensor.einsum("bdn,bn->bd", self.ssm_state, C)
y = y + self.D * x
y = y * z.swish() # (B D)
out = self.out_proj(y)
return out.unsqueeze(1), conv_state, ssm_state
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
  """Return this layer's (conv_state, ssm_state) pair, creating zeroed states on first use.

  The pair is kept in `inference_params.key_value_memory_dict` under `self.layer_idx`,
  so later calls for the same layer get the very same tensors back.
  `initialize_states` is accepted but never read by this method.
  """
  assert self.layer_idx is not None
  cache = inference_params.key_value_memory_dict
  # fast path: this layer already has its states cached
  if self.layer_idx in cache:
    conv_state, ssm_state = cache[self.layer_idx]
    return conv_state, ssm_state
  # first call for this layer: allocate zero states and remember them
  conv_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_conv).contiguous().realize()
  ssm_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_state).realize()
  cache[self.layer_idx] = (conv_state, ssm_state)
  return conv_state, ssm_state
return out.unsqueeze(1)
class MambaBlock:
def __init__(self, dim: int, norm_eps: float = 1e-5, rms_norm: bool = True, layer_idx: Optional[int] = None):
@@ -237,10 +222,10 @@ class MambaBlock:
else:
raise NotImplementedError
def __call__(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
def __call__(self, hidden_states: Tensor, residual: Optional[Tensor] = None):
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm(residual)
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
hidden_states = self.mixer(hidden_states)
return hidden_states, residual
class MambaBackbone:
@@ -250,15 +235,14 @@ class MambaBackbone:
if rms_norm:
self.norm_f = nn.RMSNorm(dim, norm_eps)
def __call__(self, input_ids: Tensor, inference_params=None) -> Any:
def __call__(self, input_ids: Tensor) -> Any:
hidden_states = self.embedding(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
hidden_states, residual = layer(hidden_states, residual)
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm_f(residual)
return hidden_states
class Mamba:
@@ -271,19 +255,13 @@ class Mamba:
self.forward_jit = TinyJit(self.forward)
def forward(self, input_ids, inference_params, num_last_tokens):
hidden_states = self.backbone(input_ids, inference_params=inference_params)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
def forward(self, input_ids:Tensor):
hidden_states = self.backbone(input_ids)
return self.lm_head(hidden_states).realize()
def __call__(self, input_ids, inference_params=None, num_last_tokens=0, jit=True):
if inference_params is None:
return self.forward(input_ids, inference_params, num_last_tokens)
if jit and inference_params.seqlen_offset > 0:
return self.forward_jit(input_ids, inference_params, num_last_tokens)
else:
return self.forward(input_ids, inference_params, num_last_tokens)
def __call__(self, input_ids):
return self.forward(input_ids)
@staticmethod
def from_pretrained(model_name: str):
weights = fetch_weights(model_name)
@@ -292,56 +270,47 @@ class Mamba:
return model
@dataclass
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference."""
max_seqlen: int
max_batch_size: int
# callers in this file set this to the number of tokens consumed so far; a value > 0 selects the single-token step path
seqlen_offset: int = 0
batch_size_offset: int = 0
# per-layer cache: MambaMixer stores a (conv_state, ssm_state) tuple here, keyed by its layer_idx
key_value_memory_dict: dict = field(default_factory=dict)
lengths_per_sample: Optional[Tensor] = None
# Re-arm the params for a new run: overwrites the two maxima and rewinds seqlen_offset to 0.
# Does not touch key_value_memory_dict or batch_size_offset.
def reset(self, max_seqlen, max_batch_size):
self.max_seqlen = max_seqlen
self.max_batch_size = max_batch_size
self.seqlen_offset = 0
if self.lengths_per_sample is not None:
# NOTE(review): zero_() looks like a torch-style in-place call -- confirm the Tensor type used here provides it
self.lengths_per_sample.zero_()
def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, sample: bool = False, top_k: int = None):
def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: bool = 1.0, sample: bool = False, top_k: int = None):
tks = tokenizer(prompt)["input_ids"]
while len(tks) < 4:
tks = [50279] + tks
# TODO: sampling
temperature = 0.5
start_pos = 0
inference_params = InferenceParams(max_seqlen=1, max_batch_size=1, seqlen_offset=0)
# Loading in the prompt tokens
logits = model.forward(Tensor([tks]))[:, -1, :]
for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
logits = model(Tensor([tks[start_pos:]]), inference_params, start_pos, jit=False)
inference_params.seqlen_offset = len(tks)
tok = logits[:, -1, :].argmax(axis=-1).item()
start_pos = len(tks)
# TODO: topk
if sample:
tok_Tens = (logits/temp).softmax().multinomial()
else:
tok_Tens = logits.argmax(axis=-1).unsqueeze(0)
tok = tok_Tens.item()
tks.append(tok)
logits = model.forward_jit(tok_Tens)[:, -1, :]
output_completions = ''.join([tokenizer.decode(output) for output in tks])
return output_completions
if __name__ == "__main__":
ORIG_PROMPT = "Why is gravity "
parser = argparse.ArgumentParser(description="Run Mamba in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--prompt", type=str, default="Why is gravity ", help="Prompt for LLM completion")
parser.add_argument("--size", type=str, default="370m",
help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]")
parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
parser.add_argument("--sample", dest="sample", action="store_true", help="Sample flag")
parser.add_argument("--temp", type=float, default=1.0, help="Sampling temp has to be <=1.0")
args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = Mamba.from_pretrained(args.size)
prompt = args.prompt
num_toks = args.n_tokens
sample = args.sample
temp = args.temp
s = time.time()
tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks)
tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp)
print(tinyoutput)
print('TIME: ', time.time() - s)
TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"
print('Outputs Match:', tinyoutput == TORCHOUTPUT)
if ORIG_PROMPT == prompt and not sample and num_toks==10 and args.size=='370m': print('Outputs Match:', tinyoutput == TORCHOUTPUT)

View File

@@ -171,17 +171,6 @@ if Device.DEFAULT == "METAL" or (OSX and Device.DEFAULT == "GPU"):
backend_test.exclude('test_mish_cpu')
backend_test.exclude('test_mish_expanded_cpu')
# TODO: llvm has problems with inf
if Device.DEFAULT in ['LLVM']:
backend_test.exclude('test_isinf_cpu')
backend_test.exclude('test_isinf_negative_cpu')
backend_test.exclude('test_isinf_positive_cpu')
# # TODO: problems with nan
if Device.DEFAULT in ['LLVM']:
backend_test.exclude('test_isnan_float16_cpu')
backend_test.exclude('test_isnan_cpu')
# disable model tests for now since they are slow
if not getenv("MODELTESTS"):
for x in backend_test.test_suite:

View File

@@ -24,7 +24,6 @@ binary_operations = [operator.add, operator.sub, operator.mul, operator.lt, oper
# TODO: LLVM comparing with nan is incorrect
if Device.DEFAULT == "LLVM":
binary_operations.remove(operator.lt)
binary_operations.remove(operator.eq)
integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor), (Tensor.bitwise_and, np.bitwise_and),
(Tensor.bitwise_or, np.bitwise_or)]

View File

@@ -392,6 +392,8 @@ class TestJitInsideJit(unittest.TestCase):
@TinyJit
def g(t): return f(t) * 3
# NOTE: first does not raise
g(Tensor([1])).realize()
with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
g(Tensor([1])).realize()

View File

@@ -777,8 +777,9 @@ class TestLinearizer(unittest.TestCase):
# check that the float4 cast collapses for all stores
for store in local_stores+global_stores:
assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op is not UOps.CAST
# check the children's vins
assert barrier.src == tuple(local_stores)
# # check the children's vins
# TODO: the src ALUs are not the same; should they be?
# assert barrier.src == tuple(local_stores)
assert len([u for u in k.uops if u.op is UOps.IF and u.src[-1] == barrier]) == 1
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")

View File

@@ -319,5 +319,14 @@ class TestAssembly(unittest.TestCase):
self.assertEqual(uops.uops[-1].arg, BinaryOps.IDIV)
self.assertEqual(uops.uops[-2].arg, BinaryOps.SHR)
class TestUOpCompare(unittest.TestCase):
  """Ordering checks for UOp comparison."""

  def test_alu_same_src_different_arg(self):
    """Two ALU UOps that differ only in their `arg` must still be ordered by `<`."""
    two = UOp(UOps.CONST, dtypes.float, (), 2.0)
    three = UOp(UOps.CONST, dtypes.float, (), 3.0)
    add = UOp(UOps.ALU, dtypes.float, (two, three), BinaryOps.ADD)
    mul = UOp(UOps.ALU, dtypes.float, (two, three), BinaryOps.MUL)
    # op, dtype and sources are identical, so only the ALU arg can break the tie
    has_order = (add < mul) or (mul < add)
    assert has_order, "add and mul with same src should have an order"
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -60,5 +60,21 @@ class TestVerifyLazyOp(unittest.TestCase):
out = LazyOp(BufferOps.STORE, (r, ), MemBuffer(0, dtypes.int, ShapeTracker.from_shape((32, 1))))
with self.assertRaises(InvalidLazyOpException): lower(out)
def test_reduce_add_store(self):
  """Adding an axis-0 SUM of a (32, 1) load back to that load and storing it must fail to lower."""
  load = LazyOp(BufferOps.LOAD, arg=MemBuffer(1, dtypes.int, ShapeTracker.from_shape((32, 1))))
  reduced = LazyOp(ReduceOps.SUM, (load, ), (0, ))
  # the reduce output is combined with its own unreduced input
  summed = reduced+load
  target = MemBuffer(0, dtypes.int, ShapeTracker.from_shape((32, 1)))
  out = LazyOp(BufferOps.STORE, (summed, ), target)
  with self.assertRaises(InvalidLazyOpException):
    lower(out)
def test_multi_reduce_simple(self):
  """A SUM whose input already contains another SUM is expected to lower without raising."""
  # first reduce: load buffer 1 expanded to (32, 32, 32), then sum over axis 1
  early_st = ShapeTracker.from_shape((32, 32)).reshape((32, 1, 32)).expand((32, 32, 32))
  early_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, early_st))
  r0 = LazyOp(ReduceOps.SUM, (early_x, ), (1, ))
  # second reduce: the same buffer viewed as (32, 1, 32), added to r0 and summed over every axis
  late_st = ShapeTracker.from_shape((32, 32)).reshape((32, 1, 32))
  late_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, late_st))
  r1 = LazyOp(ReduceOps.SUM, (late_x + r0, ), (0, 1, 2))
  out_buf = MemBuffer(0, dtypes.float, ShapeTracker.from_shape((1, 1, 1)))
  out = LazyOp(BufferOps.STORE, (r1, ), out_buf)
  lower(out)
if __name__ == '__main__':
unittest.main()

View File

@@ -39,7 +39,7 @@ class UOp:
def cmp_tuple(self):
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
(type(self.op), self.op.value), self.dtype, self.src)
self.arg.value, self.dtype, self.src)
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
def __repr__(self):
return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}"
@@ -55,6 +55,7 @@ class UOp:
def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))
def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
def ne(self, x): return UOp.alu(BinaryOps.CMPNE, self, ufix(self.dtype, x))
def eq(self, x): return -self.ne(x)
def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
def ge(self, x): return -self.lt(x)
def max(self, x): return UOp.alu(BinaryOps.MAX, self, x)
@@ -79,6 +80,44 @@ class UOp:
@property # parents with self
def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR])
def divides(self, v):
  """Conservatively report whether this UOp's value is known to be a multiple of `v`.

  Only CONST, ALU ADD and ALU MUL are analysed; anything else answers False,
  meaning "not proven" rather than "proven not divisible".
  """
  if self.op is UOps.CONST:
    return self.arg%v == 0
  if self.op is not UOps.ALU:
    return False  # generic false if we aren't sure
  if self.arg is BinaryOps.ADD:
    # a sum is a multiple only when every term is
    for term in self.src:
      if not term.divides(v): return False
    return True
  if self.arg is BinaryOps.MUL:
    # a product is a multiple as soon as one factor is
    for factor in self.src:
      if factor.divides(v): return True
    return False
  return False  # generic false if we aren't sure
# Walk every UOp in `uops` and assert that its dtype is consistent with its op, arg and sources.
# Returns nothing; the first violation raises AssertionError with a message naming the mismatch.
# The checks are plain `assert` statements, so they are skipped when Python runs with -O.
def type_verify(uops):
for u in uops:
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
if uop in (UOps.CONST, UOps.DEFINE_ACC):
if uop is UOps.DEFINE_ACC:
# the accumulator's first source must carry the scalar form of the accumulator dtype
assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
# from here on, check the first source's arg instead of the DEFINE_ACC's own arg
arg = src[0].arg
# the Python type of the constant must match what dtypes.as_const yields for this dtype
# (presumably applies to both CONST and DEFINE_ACC -- nesting is not visible in this view, confirm)
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
# a CAST to a dtype with count > 1 must have exactly one source per element
if uop is UOps.CAST and dtype is not None and dtype.count > 1: assert len(src) == dtype.count
# LOAD with more than 3 sources and an ALU as src[2]: src[2] must be bool and src[3] must have the load's dtype
# (presumably a gate and its fallback value -- confirm)
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
# STORE with exactly 4 sources: src[3] must be bool (presumably the store gate -- confirm)
if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
if uop is UOps.ALU:
# unary ops keep the dtype of their single input
if arg in UnaryOps:
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
# comparisons produce bool from two inputs of the same dtype
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
# IDIV needs integer inputs and an integer output
elif arg is BinaryOps.IDIV:
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
# the distance to shift isn't typechecked
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
# every remaining binary op: output and both inputs share one dtype
elif arg in BinaryOps:
assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
# WHERE: bool selector, and both choices share the output dtype
elif arg == TernaryOps.WHERE:
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
def uop_alu_resolve(u:UOp) -> sint:
if u.op is UOps.CONST: return u.arg
@@ -99,11 +138,13 @@ class UPat:
name: Optional[str] = None
dtype: Optional[Union[DType, Set[DType]]] = None
allow_len: Set[int] = field(default_factory=set)
allow_any_len: bool = False
@staticmethod
def compile(u: UOp, name:Optional[str]=None) -> UPat:
if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None, name, u.dtype)
return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None,
name, u.dtype, allow_any_len=(isinstance(name, str) and 'allow_any_len' in name))
T = TypeVar("T")
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool: return m2 not in m1 if isinstance(m1, set) else m2 != m1
@@ -118,7 +159,7 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
# try all permutations if it's a list
# repeat if it's a UPat
for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len): return False
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len) and not pat.allow_any_len: return False
new_store = store.copy()
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
store.update(new_store)
@@ -141,7 +182,7 @@ class PatternMatcher:
def rewrite(self, uop:UOp) -> Optional[UOp]:
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
store: Dict[str, UOp] = {}
if _match(uop, p, store): return fxn(**store)
if _match(uop, p, store) and (ret:=fxn(**store)) is not None: return ret # NOTE: if it returns None, we keep trying to match
return None
def sum_collapse(phi_input, loop, val1, val2):
@@ -328,7 +369,7 @@ class UOpGraph:
for i,u in enumerate(self):
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
def linearize(self, extra_pm:Optional[PatternMatcher]=None, do_type_verify=True):
# NOTE: relinearizing should be okay
#assert self._uops is None, "already linearized"
@@ -398,7 +439,7 @@ class UOpGraph:
assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
self._uops = self._uops[:-1]
if type_verify: self.type_verify()
if do_type_verify: type_verify(self.uops)
# *** checker functions ***
@@ -434,33 +475,3 @@ class UOpGraph:
assert u.arg[1] is not None
flops += 2 * prod(u.arg[1]) // 32 * mults
return flops, mem
# Assert dtype consistency for every UOp in self.uops; raises AssertionError on the first violation.
def type_verify(self):
for u in self.uops:
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
if uop in (UOps.CONST, UOps.DEFINE_ACC):
if uop is UOps.DEFINE_ACC:
# the accumulator's first source must carry the scalar form of the accumulator dtype
assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
# from here on, check the first source's arg instead of the DEFINE_ACC's own arg
arg = src[0].arg
# the Python type of the constant must match what dtypes.as_const yields for this dtype
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
# LOAD with more than 3 sources and an ALU as src[2]: src[2] must be bool and src[3] must have the load's dtype
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
# STORE with exactly 4 sources: src[3] must be bool
if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
if uop is UOps.ALU:
# unary ops keep the dtype of their single input
if arg in UnaryOps:
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
# comparisons produce bool from two inputs of the same dtype
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
# IDIV needs integer inputs and an integer output
elif arg is BinaryOps.IDIV:
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
# the distance to shift isn't typechecked
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
# every remaining binary op: output and both inputs share one dtype
elif arg in BinaryOps:
assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
# WHERE: bool selector, and both choices share the output dtype
elif arg == TernaryOps.WHERE:
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"

View File

@@ -3,7 +3,7 @@ from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, O
import functools, itertools, collections
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, ContextVar, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
@@ -106,7 +106,6 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
return list({id(x):x for x in wait_nodes}.values())
ReturnType = TypeVar('ReturnType')
IN_JIT = ContextVar('IN_JIT', 0)
class TinyJit(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType]):
self.fxn = fxn
@@ -146,20 +145,22 @@ class TinyJit(Generic[ReturnType]):
[dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
if not JIT or self.cnt == 0:
if IN_JIT: raise RuntimeError("having TinyJit inside another TinyJit is not supported")
# jit ignore
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value, IN_JIT=1):
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
self.ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
elif self.cnt == 1:
# jit capture
self.expected_names: List[Union[int, str]] = names
self.expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = st_vars_dtype_device
if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
capturing.append(self)
self.ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
capturing.clear()
try:
self.ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
except Exception as e: raise e
finally: capturing.clear()
del self.buffer_replace
assert len(self.jit_cache), "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_buffers)} inputs")

View File

@@ -154,9 +154,8 @@ def verify_lazyop(*ast:LazyOp):
for x in op.src: dfs(x, st)
# only reduceop is allowed to change shape, limited to turning n to 1
if op.op in ReduceOps:
expected_shape = tuple(1 if i in op.arg else s for i,s in enumerate(sts[op.src[0]].shape))
assert st.shape == expected_shape, f"unexpected reduceop shape {st.shape} != {expected_shape}"
st = ShapeTracker.from_shape(expected_shape)
assert isinstance(op.arg, tuple)
st = ShapeTracker.from_shape(tuple(1 if i in op.arg else s for i,s in enumerate(sts[op.src[0]].shape)))
else:
# movementops are pushed to the edges with LOAD
if op.op in BufferOps: st = op.arg.st

View File

@@ -86,10 +86,6 @@ class LLVMRenderer(Renderer):
for a in func.args:
if a.type.is_pointer: a.add_attribute("noalias")
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it is allowed to use vectorization optimizations
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
func.attributes.add('"no-nans-fp-math"="true"')
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
loop_blocks: List = []
reduce_phis: List = []

View File

@@ -33,7 +33,9 @@ class MetalProgram:
with tempfile.NamedTemporaryFile(delete=True) as shader:
shader.write(lib)
shader.flush()
os.system(f"cd {pathlib.Path(__file__).parents[2]}/extra/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
ret = os.system(f"cd {pathlib.Path(__file__).parents[2]}/extra/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
if ret:
print("Error running disassembler: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")
assert lib[:4] == b"MTLB", "Invalid Metal library. Could be due to using conda. Try system python or METAL_XCODE=1 DISABLE_COMPILER_CACHE=1."
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None))