mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-07-01 03:22:07 +08:00
368 lines
21 KiB
Python
368 lines
21 KiB
Python
from __future__ import annotations
|
|
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
|
|
from tinygrad import Tensor, nn, UOp, TinyJit, getenv
|
|
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored
|
|
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
|
|
|
|
class SimpleTokenizer:
  """Byte-level BPE tokenizer driven by a vocabulary of normal and special tokens.

  Normal tokens are stored keyed by their raw bytes; special tokens are matched literally in the
  input text and never split. `preset` selects the chat framing used by `role` and `end_turn`.
  """
  def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
    """Build lookup tables and splitting regexes.

    Raises ValueError when `preset` is not one of the supported names.
    """
    if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
    # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
    identity_bytes = [*range(33, 127), *range(161, 173), *range(174, 256)]  # bytes that map to themselves
    remapped_bytes = [b for b in range(256) if b not in identity_bytes]  # these are shifted to chr(256+i)
    self._byte_decoder = {chr(b): b for b in identity_bytes}
    self._byte_decoder.update({chr(256+offset): b for offset, b in enumerate(remapped_bytes)})

    # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L286
    # 0x323b0 is one past the max codepoint in unicode categories L/N/Z (0x323af is max L)
    def ucat_range(pre: str):
      # every codepoint whose unicode category starts with `pre`, escaped for use inside a regex character class
      return "".join(re.escape(chr(cp)) for cp in range(0x323b0) if unicodedata.category(chr(cp)).startswith(pre))
    r_ws = r"\t\n\x0b\x0c\r\x85" + ucat_range("Z")
    r_p_N = ucat_range("N")
    r_p_L = ucat_range("L")
    word_pattern = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + \
      f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+"
    self._split_to_word = re.compile(word_pattern)
    # with no special tokens, use a pattern that can never match
    if special_tokens: sentence_pattern = "|".join(re.escape(tok) for tok in special_tokens.keys())
    else: sentence_pattern = r"(?!)"
    self._split_to_sentence = re.compile(sentence_pattern)

    # normal tokens arrive in the printable byte-encoded form; store them keyed by the real bytes
    self._normal_tokens = {bytes(self._byte_decoder[c] for c in tok): tid for tok, tid in normal_tokens.items()}
    self._special_tokens = special_tokens
    id_to_bytes = {tid: raw for raw, tid in self._normal_tokens.items()}
    id_to_bytes.update({tid: text.encode() for text, tid in self._special_tokens.items()})
    self._tok2bytes = id_to_bytes
    self.preset = preset

  @staticmethod
  def from_gguf_kv(kv:dict):
    """Create a tokenizer from GGUF metadata: entries whose token_type is 1 are normal, the rest special."""
    # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
    vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
    def is_normal(entry): return kv["tokenizer.ggml.token_type"][entry[1]] == 1
    normal_tokens, special_tokens = partition(vocab, is_normal)
    return SimpleTokenizer(dict(normal_tokens), dict(special_tokens), kv["tokenizer.ggml.pre"])

  def _encode_word(self, word:bytes) -> list[int]:
    """Encode one pre-split word, merging adjacent pieces by lowest token id until nothing merges.

    Raises RuntimeError if a remaining piece is not in the vocabulary.
    """
    whole = self._normal_tokens.get(word)
    if whole is not None: return [whole]
    pieces = [bytes([b]) for b in word]
    # greedily merge any parts that we can
    while True:
      best_rank, best_idx = sys.maxsize, -1
      for j, (left, right) in enumerate(zip(pieces, pieces[1:])):
        rank = self._normal_tokens.get(left + right, sys.maxsize)
        # strict comparison keeps the leftmost pair on ties and never selects an unknown pair
        if rank < best_rank: best_rank, best_idx = rank, j
      if best_idx == -1: break
      pieces[best_idx:best_idx+2] = [pieces[best_idx] + pieces[best_idx+1]]
    try: return [self._normal_tokens[p] for p in pieces]
    except KeyError: raise RuntimeError("token not found")

  def _encode_sentence(self, chunk:str) -> list[int]:
    """Encode text that contains no special tokens."""
    ids: list[int] = []
    for word in self._split_to_word.findall(chunk): ids.extend(self._encode_word(word.encode()))
    return ids

  def encode(self, text:str) -> list[int]:
    """Encode `text`, mapping literal special-token strings straight to their ids."""
    tokens: list[int] = []
    cursor = 0
    for match in self._split_to_sentence.finditer(text):
      begin, end = match.span(0)
      tokens.extend(self._encode_sentence(text[cursor:begin]) + [self._special_tokens[text[begin:end]]])
      cursor = end
    return tokens + self._encode_sentence(text[cursor:])

  def decode(self, ids:list[int]) -> str:
    """Join the bytes of each id and decode them, replacing invalid sequences."""
    raw = b''.join(self._tok2bytes[tid] for tid in ids)
    return raw.decode(errors='replace')

  def role(self, role:str):
    """Return the token ids that open a turn for `role` in the preset's chat format."""
    if self.preset == 'olmo': header = "<|" + role + "|>\n"  # OLMoE Instruct format
    elif self.preset == 'qwen2': header = "<|im_start|>" + role + "\n"
    else: header = "<|start_header_id|>" + role + "<|end_header_id|>\n\n"
    return self.encode(header)

  def end_turn(self, eos_id:int):
    """Return the token ids that close a turn in the preset's chat format."""
    if self.preset == 'olmo': return self.encode("\n")
    if self.preset == 'qwen2': return [eos_id] + self.encode("\n")
    return [eos_id]
|
|
|
|
@functools.cache
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
  """Return the rotary-embedding table for positions 0..end-1.

  Each row holds the cosines of the dim//2 rotation angles followed by their sines.
  Results are memoized per (dim, end, theta).
  """
  inv_freq = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
  # outer product: one row per position, one column per frequency
  angles = Tensor.arange(end).unsqueeze(dim=1) * inv_freq.unsqueeze(dim=0)
  cos_part, sin_part = angles.cos(), angles.sin()
  return cos_part.cat(sin_part, dim=-1).contiguous()
|
|
|
|
class ExpertWeights:
  """Like nn.Linear but with num_experts dimension. Weight shape: (num_experts, out_features, in_features)."""
  def __init__(self, num_experts:int, in_features:int, out_features:int):
    self.weight = Tensor.zeros(num_experts, out_features, in_features)

  def __call__(self, sel:Tensor, x:Tensor) -> Tensor:
    """Apply the experts picked by `sel` to `x`."""
    # sel: (B, T, k), x: (B, T, 1, in) or (B, T, k, in) -> output: (B, T, k, out)
    chosen = self.weight[sel]  # gather the weight matrix of each selected expert
    projected = x.unsqueeze(-2) @ chosen.transpose(-1, -2)
    return projected.squeeze(-2)
|
|
|
|
def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
  """Rotate `x` by the angles in `freqs_cis`, pairing the first half of the last axis with the second half.

  `freqs_cis` is reshaped to (1, 1, x.shape[2], -1) and split into cos and sin halves.
  """
  assert x.shape[-1] % 2 == 0
  table = freqs_cis.reshape(1, 1, x.shape[2], -1)
  cos, sin = table.chunk(2, dim=-1)
  first_half, second_half = x.chunk(2, dim=-1)
  rotated_first = first_half * cos - second_half * sin
  rotated_second = second_half * cos + first_half * sin
  return rotated_first.cat(rotated_second, dim=-1)
|
|
|
|
class TransformerBlock:
  """One decoder layer: normalized self-attention with a KV cache, then a dense or mixture-of-experts feed-forward.

  Both sub-layers normalize their input first and add their result back onto it (residual).
  """
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float,
               max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
    """Create the layer's parameters.

    qk_norm is the feature size of the optional q/k RMSNorms (0 disables them).
    num_experts > 0 selects the mixture-of-experts feed-forward, otherwise a dense gated one is built.
    """
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads
    self.head_dim = head_dim
    self.rope_theta = rope_theta
    self.max_context = max_context
    self.qk_norm = qk_norm

    # --- attention projections (all linear, bias-free) ------------------
    q_proj_out = self.head_dim * n_heads
    kv_proj_out = self.head_dim * n_kv_heads
    self.attn_q = nn.Linear(dim, q_proj_out, bias=False)
    self.attn_k = nn.Linear(dim, kv_proj_out, bias=False)
    self.attn_v = nn.Linear(dim, kv_proj_out, bias=False)
    self.attn_output = nn.Linear(q_proj_out, dim, bias=False)

    # --- RMSNorms --------------------------------------------------------
    self.attn_norm = nn.RMSNorm(dim, norm_eps)
    self.ffn_norm = nn.RMSNorm(dim, norm_eps)
    if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(qk_norm, norm_eps), nn.RMSNorm(qk_norm, norm_eps)

    # --- feed-forward (MoE or dense) -------------------------------------
    if num_experts > 0:
      self.num_experts_per_tok = num_experts_per_tok
      self.ffn_gate_inp = nn.Linear(dim, num_experts, bias=False) # router
      self.ffn_gate_exps = ExpertWeights(num_experts, dim, hidden_dim)
      self.ffn_up_exps = ExpertWeights(num_experts, dim, hidden_dim)
      self.ffn_down_exps = ExpertWeights(num_experts, hidden_dim, dim)
    else:
      self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
      self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
      self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Attend over cached plus new positions and return `x` with the attention output added.

    The T new keys/values are written into the cache at start_pos..start_pos+T.
    """
    x_norm = self.attn_norm(x) # (B,T,D)
    q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
    # a q/k norm whose size differs from head_dim is applied to the flat projections, before the head split
    if self.qk_norm and self.qk_norm != self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

    B, T, _ = x.shape
    q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B,H,T,Hd)
    k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
    v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
    # a q/k norm of size head_dim is applied per head, after the split
    if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

    # take the rows of the (memoized) rotary table for the positions being processed
    freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T]
    q = apply_rope(q, freqs_cis)
    k = apply_rope(k, freqs_cis)

    # the cache is created on first use; index 0 of the leading axis holds keys, index 1 holds values
    if not hasattr(self, "cache_kv"):
      self.cache_kv = Tensor.zeros(2, B, self.n_kv_heads, self.max_context, self.head_dim, dtype=k.dtype, device=k.device).contiguous().realize()
    self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
    k = self.cache_kv[0, :, :, 0:start_pos+T, :]
    v = self.cache_kv[1, :, :, 0:start_pos+T, :]

    # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
    # no mask is built for a single new token (T == 1)
    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(int(start_pos)+1) if T > 1 else None
    attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
    attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
    attn = self.attn_output(attn)
    return x + attn

  def _feed_forward(self, h: Tensor) -> Tensor:
    """Return `h` plus the feed-forward output; the MoE path is taken when expert weights exist."""
    h_norm = self.ffn_norm(h)
    if hasattr(self, 'ffn_gate_exps'):
      x = h_norm.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting
      # router: softmax over all experts, then keep the top num_experts_per_tok probabilities and indices
      probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).topk(self.num_experts_per_tok) # (B, T, k) each
      x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (B, T, k, D)
      # weight each selected expert's output by its router probability and sum over the k experts
      return h + (x_down * probs.unsqueeze(-1)).sum(axis=2) # (B, T, D)
    # TODO: remove the need for this contiguous
    gated = self.ffn_gate(h_norm).silu().contiguous() * self.ffn_up(h_norm)
    return h + self.ffn_down(gated)

  def __call__(self, x: Tensor, start_pos: int|UOp):
    """Run attention then feed-forward on `x`."""
    return self._feed_forward(self._attention(x, start_pos)).contiguous()
|
|
|
|
class Transformer:
  """Decoder-only language model: token embedding, a stack of TransformerBlocks, a final norm and an output projection."""
  def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
               max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
    """Create `num_blocks` identical blocks plus the embedding, final norm and output layers."""
    self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm,
                                 num_experts, num_experts_per_tok) for _ in range(num_blocks)]
    self.token_embd = nn.Embedding(vocab_size, dim)
    self.output_norm = nn.RMSNorm(dim, norm_eps)
    self.output = nn.Linear(dim, vocab_size, bias=False)
    self.max_context = max_context
    # JIT is used if T=1 and start_pos is a UOp. TODO: make this not needed by including T in the JIT and making start_pos always a UOp
    self.forward_jit = TinyJit(self.forward)

  def forward(self, tokens:Tensor, start_pos:int|UOp) -> Tensor:
    """Run the model on `tokens` and return the argmax token id at the last position, with the last axis kept."""
    x = self.token_embd(tokens) # (B, T, D)
    for block in self.blk: x = block(x, start_pos)
    # TODO: add temperature
    return self.output(self.output_norm(x))[:, -1, :].softmax(-1, dtype="float").argmax(-1, keepdim=True)

  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
    """Dispatch to the JIT-compiled forward for single-token calls with a UOp start_pos (unless JIT=0), else run forward directly."""
    return (self.forward_jit if getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp) else self.forward)(tokens, start_pos)

  @staticmethod
  def from_gguf(gguf:Tensor, max_context:int|None=None, realize=True) -> tuple[Transformer, dict]:
    """Build a Transformer from a GGUF tensor and return it together with the GGUF key/value metadata.

    `max_context` is capped at the model's own context length; None means use the model's value.
    When `realize` is true the parameters are realized before returning.
    """
    # TODO: remove the need for copy to default device
    kv, state_dict = nn.state.gguf_load(gguf.to(None))

    # all state items should be float16, not float32
    state_dict = {k:v.cast('float16') if getenv("HALF", 1) else v for k,v in state_dict.items()}

    # some models like Llama 3.2 don't have an output.weight, they just tie to the token_embd.weight
    if 'output.weight' not in state_dict: state_dict['output.weight'] = state_dict['token_embd.weight']

    arch = kv['general.architecture']
    max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
    n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']

    # Permute Q/K weights from interleaved to half-split RoPE layout (llama-style models only)
    if arch == 'llama':
      # values of existing keys are replaced in place; no keys are added or removed during iteration
      for name in state_dict:
        if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
        if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)

    # hidden_dim prefers the expert feed-forward length when present; head_dim falls back to dim // n_heads;
    # qk_norm is the size of the first block's q-norm weight, or 0 when the checkpoint has none
    model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
                        hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']),
                        n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
                        vocab_size=len(kv['tokenizer.ggml.tokens']),
                        head_dim=kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads),
                        rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context,
                        qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
                        num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0))
    nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
    # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
    for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
    if realize: Tensor.realize(*params)
    return model, kv

  def generate(self, tokens:list[int], start_pos=0):
    """Yield newly generated token ids one at a time until `tokens` holds max_context entries.

    `tokens` is extended in place with every generated id. The first step feeds tokens[start_pos:];
    each later step feeds only the id produced by the previous step.
    NOTE(review): presumably positions before start_pos are already in the blocks' KV caches — confirm against callers.
    """
    v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
    t = Tensor([tokens[start_pos:]], dtype="int32")
    while len(tokens) < self.max_context:
      # a bound symbolic start_pos is passed only for single-token steps past position 0 (and only when SYM is enabled)
      t = self(t, v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
      next_id = int(t.item())
      tokens.append(next_id)
      start_pos = len(tokens) - 1
      yield next_id
|
|
|
|
# Model name -> GGUF download URL.
# Insertion order matters: the first key is used as the default value of the --model argument.
models = {
  "llama3.2:1b": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
  "llama3.2:1b-q4": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
  "llama3.2:3b": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q6_K.gguf",
  "llama3.2:3b-f16": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-f16.gguf",
  "llama3.1:8b": "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf",
  "qwen3:0.6b": "https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q8_0.gguf",
  "qwen3:1.7b": "https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_K_M.gguf",
  "qwen3:8b": "https://huggingface.co/Qwen/Qwen3-8B-GGUF/resolve/main/Qwen3-8B-Q4_K_M.gguf",
  "qwen3:30b-a3b": "https://huggingface.co/Qwen/Qwen3-30B-A3B-GGUF/resolve/main/Qwen3-30B-A3B-Q4_K_M.gguf",
  "olmoe": "https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct-GGUF/resolve/main/olmoe-1b-7b-0924-instruct-q4_k_m.gguf",
}
|
|
|
|
# *** simple OpenAI compatible server on 11434 to match ollama ***
|
|
# OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt
|
|
|
|
# Single-page chat UI returned by every GET. The script posts the whole conversation to
# /v1/chat/completions with stream=true and appends each streamed delta to the newest reply bubble.
# The user's text is inserted through innerHTML, so '&' and '<' are escaped first: otherwise typed
# markup would be parsed and executed by the browser instead of being shown as text.
CHAT_HTML = b'''<!DOCTYPE html><html><head><title>tinygrad chat</title><style>
* { margin: 0 }
body { background: #212121; color: #e3e3e3; font-family: system-ui;
height: 100vh; display: flex; flex-direction: column }
#chat { flex: 1; overflow-y: auto; padding: 20px }
.msg { padding: 10px 16px; margin: 8px 0; white-space: pre-wrap; border-radius: 18px }
.user { background: #2f2f2f; margin-left: auto; width: fit-content; max-width: 70% }
#input { max-width: 768px; width: 100%; margin: 20px auto; padding: 14px 20px;
background: #2f2f2f; color: inherit; font: inherit;
border: none; outline: none; resize: none; border-radius: 24px; field-sizing: content }
</style></head><body><div id="chat"></div>
<textarea id="input" rows="1" placeholder="Ask anything"></textarea>
<script>
input.onkeydown = (e) => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); send() } }
const msgs = [];
async function send() {
if (!input.value.trim()) return;
msgs.push({role: 'user', content: input.value.trim()});
chat.innerHTML += '<div class="msg user">' + input.value.trim().replace(/&/g, '&amp;').replace(/</g, '&lt;') + '</div>';
input.value = '';
const d = document.createElement('div'); d.className = 'msg'; chat.appendChild(d);
const r = await fetch('/v1/chat/completions', {method: 'POST', headers: {'Content-Type': 'application/json'},
body: JSON.stringify({model: 'llama', messages: msgs, stream: true})});
for (const rd = r.body.getReader(), dec = new TextDecoder();;) {
const {done, value} = await rd.read();
if (done) break;
for (const ln of dec.decode(value).split('\\n'))
if (ln.startsWith('data: ') && !ln.includes('[DONE]'))
try { d.textContent += JSON.parse(ln.slice(6)).choices[0]?.delta?.content || '' } catch {}
chat.scrollTop = chat.scrollHeight;
}
msgs.push({role: 'assistant', content: d.textContent});
}
</script></body></html>'''
|
|
|
|
class Handler(HTTPRequestHandler):
  """HTTP handler that serves the chat page on GET and OpenAI-style chat completions on POST.

  Uses the module-level globals `model`, `tok`, `bos_id` and `eos_id` that the `__main__` section sets up.
  NOTE(review): `send_data` and `stream_json` are presumably provided by the HTTPRequestHandler base — confirm.
  """
  def log_request(self, code='-', size='-'): pass  # silence per-request logging; run_model writes its own log line
  def do_GET(self): self.send_data(CHAT_HTML, content_type="text/html")

  def run_model(self, ids:list[int], model_name:str, include_usage=False):
    """Generate a reply to the prompt `ids`, yielding chat.completion.chunk dicts.

    Yields a role chunk, one chunk per generated token (stopping at eos_id), a final "stop"
    chunk, and a usage chunk when `include_usage` is true. Throughput is logged to stderr.
    """
    stderr_log(f"{self.path} {colored('--', 'BLACK')} in:{len(ids):5d} {colored('--', 'BLACK')} ")
    tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
    yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
    out: list[int] = []
    # pt marks the end of prefill. It is initialised here so the final log line still works when
    # model.generate yields nothing (e.g. the prompt already fills the context window).
    st = pt = time.perf_counter()
    for next_id in model.generate(ids):
      if len(out) == 0:
        pt = time.perf_counter()
        stderr_log(f"prefill:{len(ids)/(pt-st):4.0f} tok/s {colored('--', 'BLACK')} ")
      if next_id == eos_id: break
      out.append(next_id)
      yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
    yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
    if include_usage:
      yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}, **tmpl}
    # guard the divisor: with no generated tokens the elapsed time can be zero on a coarse clock
    gen_elapsed = max(time.perf_counter() - pt, 1e-9)
    stderr_log(f"out:{len(out):5d} {colored('--', 'BLACK')} gen: {len(out)/gen_elapsed:4.0f} tok/s\n")

  def do_POST(self):
    """Handle POST /v1/chat/completions; any other path raises RuntimeError."""
    raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
    body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
    if DEBUG >= 1: print(json.dumps(body, indent=2))
    if self.path == "/v1/chat/completions":
      # extract tokens
      ids: list[int] = [bos_id] if bos_id is not None else []
      for msg in body["messages"]:
        ids += tok.role(msg["role"])
        # content can be a str or a list
        content = msg["content"]
        if isinstance(content, str): ids += tok.encode(content)
        elif isinstance(content, list):
          for c in content:
            if c["type"] == "text": ids += tok.encode(c["text"])
            else: raise RuntimeError(f"unhandled type: {c['type']}")
        else: raise RuntimeError(f"unknown content type: {type(content)}")
        ids += tok.end_turn(eos_id)
      ids += tok.role("assistant")

      # reply
      # usage is always included for non-streaming requests, and for streaming ones only when asked for
      chunks = self.run_model(ids, body["model"], not body.get("stream") or body.get("stream_options",{}).get("include_usage", False))
      if body.get("stream"): self.stream_json(chunks)
      else:
        out = []
        for c in chunks: out.append(c["choices"][0]["delta"].get("content", "") if c["choices"] else "")
        # c is the last chunk (the usage chunk here); reuse its id/created/model/usage fields for the full response
        self.send_data(json.dumps({**c, "object":"chat.completion",
          "choices":[{"index":0, "message":{"role":"assistant","content":"".join(out)}, "finish_reason":"stop"}]}).encode())
    else:
      raise RuntimeError(f"unhandled path {self.path}")
|
|
|
|
if __name__ == "__main__":
  # Entry point: load a model, then either benchmark it, serve the HTTP API, or run an interactive chat.
  # model/tok/bos_id/eos_id are deliberately module-level names: Handler reads them as globals.
  parser = argparse.ArgumentParser()
  parser.add_argument("--model", choices=list(models.keys()), default=list(models.keys())[0], help="Model choice")
  parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length")
  parser.add_argument("--serve", nargs='?', type=int, const=11434, metavar="PORT", help="Run OpenAI compatible API (optional port, default 11434)")
  parser.add_argument("--benchmark", nargs='?', type=int, const=20, metavar="COUNT", help="Benchmark tok/s (optional count, default 20)")
  args = parser.parse_args()

  # load the model
  model, kv = Transformer.from_gguf(Tensor.from_url(models[args.model]), args.max_context)
  if DEBUG >= 1: print(f"using model {args.model}")

  # do benchmark
  if args.benchmark:
    param_bytes = sum(x.nbytes() for x in nn.state.get_parameters(model))
    gen = model.generate([0], 0)
    for _ in range(args.benchmark):
      GlobalCounters.reset()
      with Timing(on_exit=lambda x: f", {1e9/x:6.2f} tok/s, {GlobalCounters.global_mem/x:7.2f} GB/s, param {param_bytes/x:7.2f} GB/s"): next(gen)
    # sys.exit rather than the bare exit(): the latter is injected by the site module and is not always available
    sys.exit(0)

  # extract some metadata
  tok = SimpleTokenizer.from_gguf_kv(kv)
  bos_id: int|None = kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None
  eos_id: int = kv['tokenizer.ggml.eos_token_id']

  # start server
  if args.serve: TCPServerWithReuse(('', args.serve), Handler).serve_forever()

  # interactive chat: ids accumulates the whole conversation across turns
  ids: list[int] = [bos_id] if bos_id is not None else []
  while 1:
    # resume from the last token already in ids, so each turn feeds only that token plus the new text
    start_pos = max(len(ids) - 1, 0)
    try:
      ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn(eos_id) + tok.role("assistant")
    except EOFError:
      break
    for next_id in model.generate(ids, start_pos):
      sys.stdout.write(tok.decode([next_id]) if next_id != eos_id else "\n\n")
      sys.stdout.flush()
      if next_id == eos_id: break
|