mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-15 17:40:13 +08:00
Merge branch 'master' into faster-approx-fix-cycle-graph
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import os, sys, math, argparse, time
|
||||
sys.path.append(os.getcwd())
|
||||
from typing import Any, Optional, Dict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from tinygrad import Tensor, TinyJit, nn
|
||||
from tinygrad.helpers import fetch
|
||||
@@ -145,56 +144,52 @@ class MambaMixer:
|
||||
|
||||
self.out_proj = nn.Linear(self.d_inner, self.dim, bias=bias)
|
||||
|
||||
def __call__(self, hidden_states: Tensor, inference_params=None):
|
||||
batch, seqlen, dim = hidden_states.shape
|
||||
def __call__(self, hidden_states: Tensor):
|
||||
batch, seqlen, _ = hidden_states.shape
|
||||
|
||||
conv_state, ssm_state = None, None
|
||||
if inference_params is not None:
|
||||
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
||||
if inference_params.seqlen_offset > 0:
|
||||
# The states are updated inplace
|
||||
out, _, _ = self.step(hidden_states[:, -1:, :], conv_state, ssm_state)
|
||||
return out
|
||||
if not hasattr(self, 'conv_state'):
|
||||
self.conv_state = Tensor.zeros(batch, self.dim * self.expand, self.d_conv).contiguous().realize()
|
||||
self.ssm_state = Tensor.zeros(batch, self.dim * self.expand, self.d_state).realize()
|
||||
|
||||
xz = self.in_proj.weight @ hidden_states.permute(2,0,1).reshape(hidden_states.shape[2],hidden_states.shape[1]*hidden_states.shape[0])
|
||||
xz = xz.reshape(xz.shape[0],xz.shape[1]//seqlen, seqlen).permute(1,0,2)
|
||||
xz = self.in_proj.weight @ hidden_states.permute(2,0,1).reshape(hidden_states.shape[2],hidden_states.shape[1]*hidden_states.shape[0])
|
||||
xz = xz.reshape(xz.shape[0],xz.shape[1]//seqlen, seqlen).permute(1,0,2)
|
||||
|
||||
if self.in_proj.bias is not None:
|
||||
xz = xz + self.in_proj.bias.reshape((-1, 1))
|
||||
if self.in_proj.bias is not None:
|
||||
xz = xz + self.in_proj.bias.reshape((-1, 1))
|
||||
|
||||
A = -self.A_log.exp()
|
||||
x, z = xz.chunk(2, dim=1)
|
||||
# Compute short convolution
|
||||
if conv_state is not None:
|
||||
conv_state.assign(x[:, :, -self.d_conv :]) # Update state (B D W)
|
||||
A = -self.A_log.exp()
|
||||
x, z = xz.chunk(2, dim=1)
|
||||
# Compute short convolution
|
||||
self.conv_state.assign(x[:, :, -self.d_conv :]) # Update state (B D W)
|
||||
x = self.conv1d(x)[..., :seqlen].swish()
|
||||
|
||||
x_dbl = self.x_proj(x.permute(0,2,1).reshape(x.shape[0]*x.shape[2], x.shape[1]))
|
||||
dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||
dt = self.dt_proj.weight @ dt.T
|
||||
dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
|
||||
B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1)
|
||||
C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1)
|
||||
x_dbl = self.x_proj(x.permute(0,2,1).reshape(x.shape[0]*x.shape[2], x.shape[1]))
|
||||
dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||
dt = self.dt_proj.weight @ dt.T
|
||||
dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
|
||||
B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1)
|
||||
C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1)
|
||||
|
||||
# TODO: actually implement selective_scan_fn
|
||||
y = selective_scan_ref(x, dt, A, B, C, self.D, z=z, delta_bias=self.dt_proj.bias, delta_softplus=True,
|
||||
return_last_state=True)
|
||||
|
||||
# TODO: actually implement selective_scan_fn
|
||||
y = selective_scan_ref(x, dt, A, B, C, self.D, z=z, delta_bias=self.dt_proj.bias, delta_softplus=True,
|
||||
return_last_state=ssm_state is not None)
|
||||
if ssm_state is not None:
|
||||
y, last_state = y
|
||||
ssm_state.assign(last_state)
|
||||
self.ssm_state.assign(last_state).realize()
|
||||
y = y.permute(0,2,1)
|
||||
out = self.out_proj(y)
|
||||
return out
|
||||
else:
|
||||
return self.step(hidden_states)
|
||||
|
||||
y = y.permute(0,2,1)
|
||||
out = self.out_proj(y)
|
||||
return out
|
||||
|
||||
def step(self, hidden_states: Tensor, conv_state: Tensor, ssm_state: Tensor):
|
||||
def step(self, hidden_states: Tensor):
|
||||
assert hidden_states.shape[1] == 1, f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
|
||||
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
||||
x, z = xz.chunk(2, dim=-1) # (B D)
|
||||
|
||||
# Conv step
|
||||
conv_state.assign(conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1))
|
||||
x = (conv_state * self.conv1d.weight.squeeze(1)).sum(-1)
|
||||
self.conv_state.assign(self.conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1).realize())
|
||||
x = (self.conv_state * self.conv1d.weight.squeeze(1)).sum(-1)
|
||||
if self.conv1d.bias is not None:
|
||||
x = x + self.conv1d.bias
|
||||
x = x.swish()
|
||||
@@ -211,23 +206,13 @@ class MambaMixer:
|
||||
dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
|
||||
dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
|
||||
dB = Tensor.einsum("db,bn->bdn", dt, B)
|
||||
ssm_state.assign(ssm_state * dA + x.unsqueeze(-1) * dB)
|
||||
y = Tensor.einsum("bdn,bn->bd", ssm_state, C)
|
||||
self.ssm_state.assign(self.ssm_state * dA + x.unsqueeze(-1) * dB)
|
||||
y = Tensor.einsum("bdn,bn->bd", self.ssm_state, C)
|
||||
y = y + self.D * x
|
||||
y = y * z.swish() # (B D)
|
||||
|
||||
out = self.out_proj(y)
|
||||
return out.unsqueeze(1), conv_state, ssm_state
|
||||
|
||||
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
|
||||
assert self.layer_idx is not None
|
||||
if self.layer_idx not in inference_params.key_value_memory_dict:
|
||||
conv_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_conv).contiguous().realize()
|
||||
ssm_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_state).realize()
|
||||
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
|
||||
else:
|
||||
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
||||
return conv_state, ssm_state
|
||||
return out.unsqueeze(1)
|
||||
|
||||
class MambaBlock:
|
||||
def __init__(self, dim: int, norm_eps: float = 1e-5, rms_norm: bool = True, layer_idx: Optional[int] = None):
|
||||
@@ -237,10 +222,10 @@ class MambaBlock:
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
|
||||
def __call__(self, hidden_states: Tensor, residual: Optional[Tensor] = None):
|
||||
residual = (hidden_states + residual) if residual is not None else hidden_states
|
||||
hidden_states = self.norm(residual)
|
||||
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
|
||||
hidden_states = self.mixer(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
class MambaBackbone:
|
||||
@@ -250,15 +235,14 @@ class MambaBackbone:
|
||||
if rms_norm:
|
||||
self.norm_f = nn.RMSNorm(dim, norm_eps)
|
||||
|
||||
def __call__(self, input_ids: Tensor, inference_params=None) -> Any:
|
||||
def __call__(self, input_ids: Tensor) -> Any:
|
||||
hidden_states = self.embedding(input_ids)
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
|
||||
hidden_states, residual = layer(hidden_states, residual)
|
||||
|
||||
residual = (hidden_states + residual) if residual is not None else hidden_states
|
||||
hidden_states = self.norm_f(residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class Mamba:
|
||||
@@ -271,19 +255,13 @@ class Mamba:
|
||||
|
||||
self.forward_jit = TinyJit(self.forward)
|
||||
|
||||
def forward(self, input_ids, inference_params, num_last_tokens):
|
||||
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
||||
if num_last_tokens > 0:
|
||||
hidden_states = hidden_states[:, -num_last_tokens:]
|
||||
def forward(self, input_ids:Tensor):
|
||||
hidden_states = self.backbone(input_ids)
|
||||
return self.lm_head(hidden_states).realize()
|
||||
|
||||
def __call__(self, input_ids, inference_params=None, num_last_tokens=0, jit=True):
|
||||
if inference_params is None:
|
||||
return self.forward(input_ids, inference_params, num_last_tokens)
|
||||
if jit and inference_params.seqlen_offset > 0:
|
||||
return self.forward_jit(input_ids, inference_params, num_last_tokens)
|
||||
else:
|
||||
return self.forward(input_ids, inference_params, num_last_tokens)
|
||||
def __call__(self, input_ids):
|
||||
return self.forward(input_ids)
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(model_name: str):
|
||||
weights = fetch_weights(model_name)
|
||||
@@ -292,56 +270,47 @@ class Mamba:
|
||||
|
||||
return model
|
||||
|
||||
@dataclass
|
||||
class InferenceParams:
|
||||
"""Inference parameters that are passed to the main model in order
|
||||
to efficienly calculate and store the context during inference."""
|
||||
max_seqlen: int
|
||||
max_batch_size: int
|
||||
seqlen_offset: int = 0
|
||||
batch_size_offset: int = 0
|
||||
key_value_memory_dict: dict = field(default_factory=dict)
|
||||
lengths_per_sample: Optional[Tensor] = None
|
||||
|
||||
def reset(self, max_seqlen, max_batch_size):
|
||||
self.max_seqlen = max_seqlen
|
||||
self.max_batch_size = max_batch_size
|
||||
self.seqlen_offset = 0
|
||||
if self.lengths_per_sample is not None:
|
||||
self.lengths_per_sample.zero_()
|
||||
|
||||
def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, sample: bool = False, top_k: int = None):
|
||||
def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: bool = 1.0, sample: bool = False, top_k: int = None):
|
||||
tks = tokenizer(prompt)["input_ids"]
|
||||
while len(tks) < 4:
|
||||
tks = [50279] + tks
|
||||
# TODO: sampling
|
||||
temperature = 0.5
|
||||
start_pos = 0
|
||||
inference_params = InferenceParams(max_seqlen=1, max_batch_size=1, seqlen_offset=0)
|
||||
|
||||
# Loading in the prompt tokens
|
||||
logits = model.forward(Tensor([tks]))[:, -1, :]
|
||||
for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
|
||||
logits = model(Tensor([tks[start_pos:]]), inference_params, start_pos, jit=False)
|
||||
inference_params.seqlen_offset = len(tks)
|
||||
tok = logits[:, -1, :].argmax(axis=-1).item()
|
||||
start_pos = len(tks)
|
||||
# TODO: topk
|
||||
if sample:
|
||||
tok_Tens = (logits/temp).softmax().multinomial()
|
||||
else:
|
||||
tok_Tens = logits.argmax(axis=-1).unsqueeze(0)
|
||||
tok = tok_Tens.item()
|
||||
tks.append(tok)
|
||||
logits = model.forward_jit(tok_Tens)[:, -1, :]
|
||||
|
||||
output_completions = ''.join([tokenizer.decode(output) for output in tks])
|
||||
return output_completions
|
||||
|
||||
if __name__ == "__main__":
|
||||
ORIG_PROMPT = "Why is gravity "
|
||||
parser = argparse.ArgumentParser(description="Run Mamba in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("--prompt", type=str, default="Why is gravity ", help="Prompt for LLM completion")
|
||||
parser.add_argument("--size", type=str, default="370m",
|
||||
help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]")
|
||||
parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
|
||||
parser.add_argument("--sample", dest="sample", action="store_true", help="Sample flag")
|
||||
parser.add_argument("--temp", type=float, default=1.0, help="Sampling temp has to be <=1.0")
|
||||
args = parser.parse_args()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
||||
model = Mamba.from_pretrained(args.size)
|
||||
prompt = args.prompt
|
||||
num_toks = args.n_tokens
|
||||
sample = args.sample
|
||||
temp = args.temp
|
||||
s = time.time()
|
||||
tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks)
|
||||
tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp)
|
||||
print(tinyoutput)
|
||||
print('TIME: ', time.time() - s)
|
||||
TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"
|
||||
print('Outputs Match:', tinyoutput == TORCHOUTPUT)
|
||||
if ORIG_PROMPT == prompt and not sample and num_toks==10 and args.size=='370m': print('Outputs Match:', tinyoutput == TORCHOUTPUT)
|
||||
11
test/external/external_test_onnx_backend.py
vendored
11
test/external/external_test_onnx_backend.py
vendored
@@ -171,17 +171,6 @@ if Device.DEFAULT == "METAL" or (OSX and Device.DEFAULT == "GPU"):
|
||||
backend_test.exclude('test_mish_cpu')
|
||||
backend_test.exclude('test_mish_expanded_cpu')
|
||||
|
||||
# TODO: llvm has problems with inf
|
||||
if Device.DEFAULT in ['LLVM']:
|
||||
backend_test.exclude('test_isinf_cpu')
|
||||
backend_test.exclude('test_isinf_negative_cpu')
|
||||
backend_test.exclude('test_isinf_positive_cpu')
|
||||
|
||||
# # TODO: problems with nan
|
||||
if Device.DEFAULT in ['LLVM']:
|
||||
backend_test.exclude('test_isnan_float16_cpu')
|
||||
backend_test.exclude('test_isnan_cpu')
|
||||
|
||||
# disable model tests for now since they are slow
|
||||
if not getenv("MODELTESTS"):
|
||||
for x in backend_test.test_suite:
|
||||
|
||||
@@ -24,7 +24,6 @@ binary_operations = [operator.add, operator.sub, operator.mul, operator.lt, oper
|
||||
# TODO: LLVM comparing with nan is incorrect
|
||||
if Device.DEFAULT == "LLVM":
|
||||
binary_operations.remove(operator.lt)
|
||||
binary_operations.remove(operator.eq)
|
||||
|
||||
integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor), (Tensor.bitwise_and, np.bitwise_and),
|
||||
(Tensor.bitwise_or, np.bitwise_or)]
|
||||
|
||||
@@ -392,6 +392,8 @@ class TestJitInsideJit(unittest.TestCase):
|
||||
@TinyJit
|
||||
def g(t): return f(t) * 3
|
||||
|
||||
# NOTE: first does not raise
|
||||
g(Tensor([1])).realize()
|
||||
with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
|
||||
g(Tensor([1])).realize()
|
||||
|
||||
|
||||
@@ -777,8 +777,9 @@ class TestLinearizer(unittest.TestCase):
|
||||
# check that the float4 cast collapses for all stores
|
||||
for store in local_stores+global_stores:
|
||||
assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op is not UOps.CAST
|
||||
# check the children's vins
|
||||
assert barrier.src == tuple(local_stores)
|
||||
# # check the children's vins
|
||||
# TODO: src ALU are not the same, should it?
|
||||
# assert barrier.src == tuple(local_stores)
|
||||
assert len([u for u in k.uops if u.op is UOps.IF and u.src[-1] == barrier]) == 1
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
|
||||
@@ -319,5 +319,14 @@ class TestAssembly(unittest.TestCase):
|
||||
self.assertEqual(uops.uops[-1].arg, BinaryOps.IDIV)
|
||||
self.assertEqual(uops.uops[-2].arg, BinaryOps.SHR)
|
||||
|
||||
class TestUOpCompare(unittest.TestCase):
|
||||
def test_alu_same_src_different_arg(self):
|
||||
a = UOp(UOps.CONST, dtypes.float, (), 2.0)
|
||||
b = UOp(UOps.CONST, dtypes.float, (), 3.0)
|
||||
|
||||
add = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.ADD)
|
||||
mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL)
|
||||
assert (add < mul) or (mul < add), "add and mul with same src should have an order"
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -60,5 +60,21 @@ class TestVerifyLazyOp(unittest.TestCase):
|
||||
out = LazyOp(BufferOps.STORE, (r, ), MemBuffer(0, dtypes.int, ShapeTracker.from_shape((32, 1))))
|
||||
with self.assertRaises(InvalidLazyOpException): lower(out)
|
||||
|
||||
def test_reduce_add_store(self):
|
||||
a = LazyOp(BufferOps.LOAD, arg=MemBuffer(1, dtypes.int, ShapeTracker.from_shape((32, 1))))
|
||||
r = LazyOp(ReduceOps.SUM, (a, ), (0, ))
|
||||
out = LazyOp(BufferOps.STORE, (r+a, ), MemBuffer(0, dtypes.int, ShapeTracker.from_shape((32, 1))))
|
||||
with self.assertRaises(InvalidLazyOpException): lower(out)
|
||||
|
||||
def test_multi_reduce_simple(self):
|
||||
early_st = ShapeTracker.from_shape((32, 32)).reshape((32, 1, 32)).expand((32, 32, 32))
|
||||
early_x = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=early_st))
|
||||
r0 = LazyOp(op=ReduceOps.SUM, src=(early_x, ), arg=(1, ))
|
||||
late_st = ShapeTracker.from_shape((32, 32)).reshape((32, 1, 32))
|
||||
late_x = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=late_st))
|
||||
r1 = LazyOp(op=ReduceOps.SUM, src=(late_x + r0, ), arg=(0, 1, 2))
|
||||
out = LazyOp(op=BufferOps.STORE, src=(r1, ), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((1, 1, 1))))
|
||||
lower(out)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -39,7 +39,7 @@ class UOp:
|
||||
def cmp_tuple(self):
|
||||
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
|
||||
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
|
||||
(type(self.op), self.op.value), self.dtype, self.src)
|
||||
self.arg.value, self.dtype, self.src)
|
||||
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
|
||||
def __repr__(self):
|
||||
return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}"
|
||||
@@ -55,6 +55,7 @@ class UOp:
|
||||
def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))
|
||||
def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
|
||||
def ne(self, x): return UOp.alu(BinaryOps.CMPNE, self, ufix(self.dtype, x))
|
||||
def eq(self, x): return -self.ne(x)
|
||||
def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
|
||||
def ge(self, x): return -self.lt(x)
|
||||
def max(self, x): return UOp.alu(BinaryOps.MAX, self, x)
|
||||
@@ -79,6 +80,44 @@ class UOp:
|
||||
@property # parents with self
|
||||
def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
|
||||
def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR])
|
||||
def divides(self, v):
|
||||
if self.op is UOps.CONST:
|
||||
return self.arg%v == 0
|
||||
if self.op is UOps.ALU:
|
||||
if self.arg is BinaryOps.ADD: return all(x.divides(v) for x in self.src)
|
||||
if self.arg is BinaryOps.MUL: return any(x.divides(v) for x in self.src)
|
||||
return False # generic false if we aren't sure
|
||||
|
||||
def type_verify(uops):
|
||||
for u in uops:
|
||||
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
|
||||
if uop in (UOps.CONST, UOps.DEFINE_ACC):
|
||||
if uop is UOps.DEFINE_ACC:
|
||||
assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
|
||||
arg = src[0].arg
|
||||
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
||||
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
|
||||
if uop is UOps.CAST and dtype is not None and dtype.count > 1: assert len(src) == dtype.count
|
||||
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
|
||||
if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
|
||||
if uop is UOps.ALU:
|
||||
if arg in UnaryOps:
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
|
||||
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
|
||||
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg is BinaryOps.IDIV:
|
||||
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
|
||||
f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
|
||||
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
|
||||
# the distance to shift isn't typechecked
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in BinaryOps:
|
||||
assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg == TernaryOps.WHERE:
|
||||
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
|
||||
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
||||
|
||||
def uop_alu_resolve(u:UOp) -> sint:
|
||||
if u.op is UOps.CONST: return u.arg
|
||||
@@ -99,11 +138,13 @@ class UPat:
|
||||
name: Optional[str] = None
|
||||
dtype: Optional[Union[DType, Set[DType]]] = None
|
||||
allow_len: Set[int] = field(default_factory=set)
|
||||
allow_any_len: bool = False
|
||||
|
||||
@staticmethod
|
||||
def compile(u: UOp, name:Optional[str]=None) -> UPat:
|
||||
if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
|
||||
return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None, name, u.dtype)
|
||||
return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None,
|
||||
name, u.dtype, allow_any_len=(isinstance(name, str) and 'allow_any_len' in name))
|
||||
|
||||
T = TypeVar("T")
|
||||
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool: return m2 not in m1 if isinstance(m1, set) else m2 != m1
|
||||
@@ -118,7 +159,7 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
|
||||
# try all permutations if it's a list
|
||||
# repeat if it's a UPat
|
||||
for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
|
||||
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len): return False
|
||||
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len) and not pat.allow_any_len: return False
|
||||
new_store = store.copy()
|
||||
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
|
||||
store.update(new_store)
|
||||
@@ -141,7 +182,7 @@ class PatternMatcher:
|
||||
def rewrite(self, uop:UOp) -> Optional[UOp]:
|
||||
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
|
||||
store: Dict[str, UOp] = {}
|
||||
if _match(uop, p, store): return fxn(**store)
|
||||
if _match(uop, p, store) and (ret:=fxn(**store)) is not None: return ret # NOTE: if it returns None, we keep trying to match
|
||||
return None
|
||||
|
||||
def sum_collapse(phi_input, loop, val1, val2):
|
||||
@@ -328,7 +369,7 @@ class UOpGraph:
|
||||
for i,u in enumerate(self):
|
||||
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")
|
||||
|
||||
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
|
||||
def linearize(self, extra_pm:Optional[PatternMatcher]=None, do_type_verify=True):
|
||||
# NOTE: relinearizering should be okay
|
||||
#assert self._uops is None, "already linearized"
|
||||
|
||||
@@ -398,7 +439,7 @@ class UOpGraph:
|
||||
assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
|
||||
self._uops = self._uops[:-1]
|
||||
|
||||
if type_verify: self.type_verify()
|
||||
if do_type_verify: type_verify(self.uops)
|
||||
|
||||
# *** checker functions ***
|
||||
|
||||
@@ -434,33 +475,3 @@ class UOpGraph:
|
||||
assert u.arg[1] is not None
|
||||
flops += 2 * prod(u.arg[1]) // 32 * mults
|
||||
return flops, mem
|
||||
|
||||
def type_verify(self):
|
||||
for u in self.uops:
|
||||
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
|
||||
if uop in (UOps.CONST, UOps.DEFINE_ACC):
|
||||
if uop is UOps.DEFINE_ACC:
|
||||
assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
|
||||
arg = src[0].arg
|
||||
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
||||
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
|
||||
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
|
||||
if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
|
||||
if uop is UOps.ALU:
|
||||
if arg in UnaryOps:
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
|
||||
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
|
||||
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg is BinaryOps.IDIV:
|
||||
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
|
||||
f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
|
||||
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
|
||||
# the distance to shift isn't typechecked
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in BinaryOps:
|
||||
assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg == TernaryOps.WHERE:
|
||||
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
|
||||
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, O
|
||||
import functools, itertools, collections
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, ContextVar, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
|
||||
from tinygrad.device import Buffer, Compiled, Device
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -106,7 +106,6 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
|
||||
return list({id(x):x for x in wait_nodes}.values())
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
IN_JIT = ContextVar('IN_JIT', 0)
|
||||
class TinyJit(Generic[ReturnType]):
|
||||
def __init__(self, fxn:Callable[..., ReturnType]):
|
||||
self.fxn = fxn
|
||||
@@ -146,20 +145,22 @@ class TinyJit(Generic[ReturnType]):
|
||||
[dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
|
||||
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
|
||||
if not JIT or self.cnt == 0:
|
||||
if IN_JIT: raise RuntimeError("having TinyJit inside another TinyJit is not supported")
|
||||
# jit ignore
|
||||
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value, IN_JIT=1):
|
||||
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
|
||||
elif self.cnt == 1:
|
||||
# jit capture
|
||||
self.expected_names: List[Union[int, str]] = names
|
||||
self.expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = st_vars_dtype_device
|
||||
if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
|
||||
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
|
||||
capturing.append(self)
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
|
||||
capturing.clear()
|
||||
try:
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
|
||||
except Exception as e: raise e
|
||||
finally: capturing.clear()
|
||||
del self.buffer_replace
|
||||
assert len(self.jit_cache), "didn't JIT anything!"
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_buffers)} inputs")
|
||||
|
||||
@@ -154,9 +154,8 @@ def verify_lazyop(*ast:LazyOp):
|
||||
for x in op.src: dfs(x, st)
|
||||
# only reduceop is allowed to change shape, limited to turning n to 1
|
||||
if op.op in ReduceOps:
|
||||
expected_shape = tuple(1 if i in op.arg else s for i,s in enumerate(sts[op.src[0]].shape))
|
||||
assert st.shape == expected_shape, f"unexpected reduceop shape {st.shape} != {expected_shape}"
|
||||
st = ShapeTracker.from_shape(expected_shape)
|
||||
assert isinstance(op.arg, tuple)
|
||||
st = ShapeTracker.from_shape(tuple(1 if i in op.arg else s for i,s in enumerate(sts[op.src[0]].shape)))
|
||||
else:
|
||||
# movementops are pushed to the edges with LOAD
|
||||
if op.op in BufferOps: st = op.arg.st
|
||||
|
||||
@@ -86,10 +86,6 @@ class LLVMRenderer(Renderer):
|
||||
for a in func.args:
|
||||
if a.type.is_pointer: a.add_attribute("noalias")
|
||||
|
||||
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
|
||||
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
|
||||
func.attributes.add('"no-nans-fp-math"="true"')
|
||||
|
||||
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
|
||||
loop_blocks: List = []
|
||||
reduce_phis: List = []
|
||||
|
||||
@@ -33,7 +33,9 @@ class MetalProgram:
|
||||
with tempfile.NamedTemporaryFile(delete=True) as shader:
|
||||
shader.write(lib)
|
||||
shader.flush()
|
||||
os.system(f"cd {pathlib.Path(__file__).parents[2]}/extra/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
|
||||
ret = os.system(f"cd {pathlib.Path(__file__).parents[2]}/extra/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
|
||||
if ret:
|
||||
print("Error running disassembler: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")
|
||||
assert lib[:4] == b"MTLB", "Invalid Metal library. Could be due to using conda. Try system python or METAL_XCODE=1 DISABLE_COMPILER_CACHE=1."
|
||||
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
|
||||
self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None))
|
||||
|
||||
Reference in New Issue
Block a user