6adb63b915
date: 2026-06-04T09:49:56 master commit: c0ab3550eca2e9daf197c46b7e4b24aa9637cf2e
138 lines
7.0 KiB
Python
138 lines
7.0 KiB
Python
import itertools
|
|
from tinygrad.helpers import dedup
|
|
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
|
|
from tinygrad.renderer.isa import ISARenderer, Register
|
|
from tinygrad.dtype import dtypes, PtrDType
|
|
|
|
PSEUDO_OPS = {Ops.CONST, Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP}
|
|
|
|
class LinearScanRegallocContext:
|
|
# returns the uop that defines the virtual register
|
|
def vdef(self, v:Register) -> UOp: return self.uops[self.live_range[v][0]]
|
|
def __init__(self, uops:list[UOp], ren:ISARenderer):
|
|
self.uops = uops
|
|
self.ren = ren
|
|
self.idx = itertools.count()
|
|
# the label associated with each loop NOTE: this is only used post regalloc and should be removed
|
|
self.loop_label: dict[UOp, str] = {}
|
|
|
|
# compute live ranges
|
|
self.live_range: dict[Register, list[int]] = {}
|
|
lr = self.live_range
|
|
ranges: list[Register] = []
|
|
for i,u in enumerate(reversed(uops)):
|
|
if u.op in PSEUDO_OPS: continue
|
|
defs = u.tag if isinstance(u.tag, tuple) else ()
|
|
for v in defs + tuple(s.reg for s in dedup(u.src)):
|
|
if isinstance(v, Register): lr.setdefault(v, []).insert(0, len(uops) - 1 - i)
|
|
for v in defs:
|
|
if v in lr and (n:=max((lr[rng][-1] for rng in ranges if lr[rng][0] <= lr[v][-1] < lr[rng][-1]), default=None)): lr[v].append(n)
|
|
if u.op is Ops.RANGE: ranges.append(u.reg)
|
|
|
|
# allocate registers
|
|
self.stack_size: int = 0
|
|
self.locals: dict[UOp, UOp] = {}
|
|
self.spills: dict[Register, UOp] = {} # mapping from virtual to stack slot
|
|
self.reals: dict[int, dict[Register, Register]] = {} # mapping from virtual to real at each program point
|
|
self.insert_before: dict[int, list[tuple[Register, Register]]] = {} # fills to be inserted at each program point
|
|
live: dict[Register, Register] = {} # mapping from virtual to real that's currently assigned to it
|
|
live_ins: list[dict[Register, Register]] = [] # mapping from virtual to real at loop entry
|
|
|
|
def alloc(cons:tuple[Register, ...], i:int) -> Register:
|
|
live_inv = {v:k for k,v in live.items()}
|
|
# allocate the best register. Registers not in live or not used again are free and have priority,
|
|
# otherwise pick the one with the furthest next use. Regs that appear first in cons have priority in case of a tie
|
|
reg,vreg = max(((r,live_inv.get(r)) for r in cons),
|
|
key=lambda rv: next((j-i for j in ([] if rv[1] is None else lr[rv[1]]) if j >= i), len(uops)))
|
|
return live.pop(vreg) if vreg is not None else reg
|
|
|
|
# assign register to spilled virtual and record load to be emitted before current uop, also assign it a stack slot
|
|
def fill(v:Register, i:int, cons:tuple[Register, ...]|None=None) -> Register:
|
|
if v not in self.spills:
|
|
dt = self.vdef(v).dtype
|
|
sz = dt.scalar().itemsize * dt.count if not isinstance(dt, PtrDType) else 8
|
|
offset = self.stack_size + (sz - self.stack_size % sz) % sz
|
|
self.spills[v] = UOp.const(dtypes.int32, offset)
|
|
self.stack_size = offset + sz
|
|
r = alloc(cons if cons is not None else v.cons, i)
|
|
self.insert_before.setdefault(i, []).append((v, r))
|
|
return r
|
|
|
|
for i,u in enumerate(uops):
|
|
if u.op in PSEUDO_OPS: continue
|
|
# allocate uses
|
|
for s in u.src:
|
|
# HACK: cause of later hacks to lower range
|
|
if u.op is Ops.END: continue
|
|
if not isinstance(v:=s.reg, Register): continue
|
|
if v not in live: live[v] = fill(v, i)
|
|
self.reals.setdefault(i, {})[v] = live[v]
|
|
|
|
# allocate defs
|
|
if isinstance(u.tag, tuple):
|
|
for j,v in enumerate(u.tag):
|
|
# register should only be defined once
|
|
assert isinstance(v, Register) and lr[v][0] == i
|
|
cons = v.cons
|
|
# two address instructions (src is reused by def) can only coalesce reused src. reused src goes first to get priority in case of a tiebreak
|
|
if ren.is_two_address(u) and j == 0:
|
|
uses = tuple(live.get(s.reg) for s in u.src)
|
|
cons = ((uses[0],) if uses[0] in cons else ()) + tuple(r for r in cons if r not in uses)
|
|
# HACK: cause the range is missing the comparison
|
|
live[v] = alloc(cons, i+1 if u.op is not Ops.RANGE else i)
|
|
self.reals.setdefault(i, {})[v] = live[v]
|
|
|
|
# allocate stack array
|
|
if u.op is Ops.DEFINE_LOCAL:
|
|
self.locals[u] = UOp.const(dtypes.int32, self.stack_size)
|
|
self.stack_size += u.dtype.nbytes()
|
|
|
|
# loop prologue, avoid loading inside the loop
|
|
if u.op is Ops.RANGE:
|
|
# we move to registers vars used in the loop sorted by next use, vars not used in the loop will not be reloaded in the epilogue
|
|
used_in_loop = [v for v in live.keys() | self.spills.keys() if any(i <= l < lr[u.reg][-1] for l in lr[v])]
|
|
sorted_uses = sorted(used_in_loop, key=lambda k: (next(l-i for l in lr[k] if l >= i), lr[k][0], k.name, k.index))
|
|
live_in: dict[Register, Register] = {}
|
|
for v in sorted_uses:
|
|
# if all the possible registers are already in live_in there's no space for this var
|
|
if set(v.cons).issubset(live_in.values()): continue
|
|
if v not in live: live[v] = fill(v, i)
|
|
live_in[v] = live[v]
|
|
live_ins.append(live_in)
|
|
|
|
# loop epilogue, reload registers that were live at loop entry
|
|
if u.op is Ops.END:
|
|
# TODO: if a uop is in a different reg in live out vs live in move between registers instead of loading
|
|
# TODO: don't reload if first use in loop is a load
|
|
for v,r in live_ins.pop().items():
|
|
if v not in live or live[v] != r: live[v] = fill(v, i, (r,))
|
|
|
|
def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
|
|
i = next(ctx.idx)
|
|
if x.op in PSEUDO_OPS: return None
|
|
nsrc = []
|
|
for j,s in enumerate(x.src):
|
|
# v here is the virtual defined by the original s as s is the rewritten version
|
|
if i in ctx.reals and (v:=ctx.uops[i].src[j].reg) in ctx.spills: nsrc.append(ctx.ren.fill(ctx.spills[v], ctx.vdef(v), ctx.reals[i][v]))
|
|
else: nsrc.append(s)
|
|
ndefs = tuple(ctx.reals[i][v] for v in x.tag) if isinstance(x.tag, tuple) else x.tag
|
|
if x.op is Ops.DEFINE_LOCAL: nx = ctx.ren.isel_matcher.rewrite(ctx.ren.stack_pointer().index(ctx.locals[x], dtype=x.dtype, tag=ndefs))
|
|
else: nx = x.replace(src=tuple(nsrc), tag=ndefs)
|
|
|
|
before = [ctx.ren.fill(ctx.spills[v], ctx.vdef(v), r) for v,r in ctx.insert_before.get(i, [])]
|
|
after = [ctx.ren.spill(ctx.spills[v], nx) for v in x.tag if v in ctx.spills] if isinstance(x.tag, tuple) else []
|
|
|
|
# alloc/dealloc stack
|
|
if ctx.stack_size > 0:
|
|
sp = ctx.ren.stack_pointer()
|
|
offset = UOp(Ops.CONST, sp.dtype, arg=ctx.stack_size)
|
|
if i == 0: before = [ctx.ren.isel_matcher.rewrite(UOp(Ops.SUB, sp.dtype, (sp, offset), tag=sp.tag))] + before
|
|
elif i == len(ctx.uops) - 2: before += [ctx.ren.isel_matcher.rewrite(UOp(Ops.ADD, sp.dtype, (sp, offset), tag=sp.tag))]
|
|
|
|
return nx, before + [nx] + after
|
|
|
|
pm_regalloc_rewrite = PatternMatcher([
|
|
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL} | PSEUDO_OPS, name="x"),
|
|
regalloc_rewrite),
|
|
])
|