Files
onepilot/tinygrad_repo/tinygrad/codegen/late/regalloc.py
T
Vehicle Researcher 6adb63b915 openpilot v0.11.1 release
date: 2026-06-04T09:49:56
master commit: c0ab3550eca2e9daf197c46b7e4b24aa9637cf2e
2026-06-04 09:50:05 -07:00

138 lines
7.0 KiB
Python

import itertools
from tinygrad.helpers import dedup
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
from tinygrad.renderer.isa import ISARenderer, Register
from tinygrad.dtype import dtypes, PtrDType
PSEUDO_OPS = {Ops.CONST, Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP}
class LinearScanRegallocContext:
# returns the uop that defines the virtual register
def vdef(self, v:Register) -> UOp: return self.uops[self.live_range[v][0]]
def __init__(self, uops:list[UOp], ren:ISARenderer):
self.uops = uops
self.ren = ren
self.idx = itertools.count()
# the label associated with each loop NOTE: this is only used post regalloc and should be removed
self.loop_label: dict[UOp, str] = {}
# compute live ranges
self.live_range: dict[Register, list[int]] = {}
lr = self.live_range
ranges: list[Register] = []
for i,u in enumerate(reversed(uops)):
if u.op in PSEUDO_OPS: continue
defs = u.tag if isinstance(u.tag, tuple) else ()
for v in defs + tuple(s.reg for s in dedup(u.src)):
if isinstance(v, Register): lr.setdefault(v, []).insert(0, len(uops) - 1 - i)
for v in defs:
if v in lr and (n:=max((lr[rng][-1] for rng in ranges if lr[rng][0] <= lr[v][-1] < lr[rng][-1]), default=None)): lr[v].append(n)
if u.op is Ops.RANGE: ranges.append(u.reg)
# allocate registers
self.stack_size: int = 0
self.locals: dict[UOp, UOp] = {}
self.spills: dict[Register, UOp] = {} # mapping from virtual to stack slot
self.reals: dict[int, dict[Register, Register]] = {} # mapping from virtual to real at each program point
self.insert_before: dict[int, list[tuple[Register, Register]]] = {} # fills to be inserted at each program point
live: dict[Register, Register] = {} # mapping from virtual to real that's currently assigned to it
live_ins: list[dict[Register, Register]] = [] # mapping from virtual to real at loop entry
def alloc(cons:tuple[Register, ...], i:int) -> Register:
live_inv = {v:k for k,v in live.items()}
# allocate the best register. Registers not in live or not used again are free and have priority,
# otherwise pick the one with the furthest next use. Regs that appear first in cons have priority in case of a tie
reg,vreg = max(((r,live_inv.get(r)) for r in cons),
key=lambda rv: next((j-i for j in ([] if rv[1] is None else lr[rv[1]]) if j >= i), len(uops)))
return live.pop(vreg) if vreg is not None else reg
# assign register to spilled virtual and record load to be emitted before current uop, also assign it a stack slot
def fill(v:Register, i:int, cons:tuple[Register, ...]|None=None) -> Register:
if v not in self.spills:
dt = self.vdef(v).dtype
sz = dt.scalar().itemsize * dt.count if not isinstance(dt, PtrDType) else 8
offset = self.stack_size + (sz - self.stack_size % sz) % sz
self.spills[v] = UOp.const(dtypes.int32, offset)
self.stack_size = offset + sz
r = alloc(cons if cons is not None else v.cons, i)
self.insert_before.setdefault(i, []).append((v, r))
return r
for i,u in enumerate(uops):
if u.op in PSEUDO_OPS: continue
# allocate uses
for s in u.src:
# HACK: cause of later hacks to lower range
if u.op is Ops.END: continue
if not isinstance(v:=s.reg, Register): continue
if v not in live: live[v] = fill(v, i)
self.reals.setdefault(i, {})[v] = live[v]
# allocate defs
if isinstance(u.tag, tuple):
for j,v in enumerate(u.tag):
# register should only be defined once
assert isinstance(v, Register) and lr[v][0] == i
cons = v.cons
# two address instructions (src is reused by def) can only coalesce reused src. reused src goes first to get priority in case of a tiebreak
if ren.is_two_address(u) and j == 0:
uses = tuple(live.get(s.reg) for s in u.src)
cons = ((uses[0],) if uses[0] in cons else ()) + tuple(r for r in cons if r not in uses)
# HACK: cause the range is missing the comparison
live[v] = alloc(cons, i+1 if u.op is not Ops.RANGE else i)
self.reals.setdefault(i, {})[v] = live[v]
# allocate stack array
if u.op is Ops.DEFINE_LOCAL:
self.locals[u] = UOp.const(dtypes.int32, self.stack_size)
self.stack_size += u.dtype.nbytes()
# loop prologue, avoid loading inside the loop
if u.op is Ops.RANGE:
# we move to registers vars used in the loop sorted by next use, vars not used in the loop will not be reloaded in the epilogue
used_in_loop = [v for v in live.keys() | self.spills.keys() if any(i <= l < lr[u.reg][-1] for l in lr[v])]
sorted_uses = sorted(used_in_loop, key=lambda k: (next(l-i for l in lr[k] if l >= i), lr[k][0], k.name, k.index))
live_in: dict[Register, Register] = {}
for v in sorted_uses:
# if all the possible registers are already in live_in there's no space for this var
if set(v.cons).issubset(live_in.values()): continue
if v not in live: live[v] = fill(v, i)
live_in[v] = live[v]
live_ins.append(live_in)
# loop epilogue, reload registers that were live at loop entry
if u.op is Ops.END:
# TODO: if a uop is in a different reg in live out vs live in move between registers instead of loading
# TODO: don't reload if first use in loop is a load
for v,r in live_ins.pop().items():
if v not in live or live[v] != r: live[v] = fill(v, i, (r,))
def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
i = next(ctx.idx)
if x.op in PSEUDO_OPS: return None
nsrc = []
for j,s in enumerate(x.src):
# v here is the virtual defined by the original s as s is the rewritten version
if i in ctx.reals and (v:=ctx.uops[i].src[j].reg) in ctx.spills: nsrc.append(ctx.ren.fill(ctx.spills[v], ctx.vdef(v), ctx.reals[i][v]))
else: nsrc.append(s)
ndefs = tuple(ctx.reals[i][v] for v in x.tag) if isinstance(x.tag, tuple) else x.tag
if x.op is Ops.DEFINE_LOCAL: nx = ctx.ren.isel_matcher.rewrite(ctx.ren.stack_pointer().index(ctx.locals[x], dtype=x.dtype, tag=ndefs))
else: nx = x.replace(src=tuple(nsrc), tag=ndefs)
before = [ctx.ren.fill(ctx.spills[v], ctx.vdef(v), r) for v,r in ctx.insert_before.get(i, [])]
after = [ctx.ren.spill(ctx.spills[v], nx) for v in x.tag if v in ctx.spills] if isinstance(x.tag, tuple) else []
# alloc/dealloc stack
if ctx.stack_size > 0:
sp = ctx.ren.stack_pointer()
offset = UOp(Ops.CONST, sp.dtype, arg=ctx.stack_size)
if i == 0: before = [ctx.ren.isel_matcher.rewrite(UOp(Ops.SUB, sp.dtype, (sp, offset), tag=sp.tag))] + before
elif i == len(ctx.uops) - 2: before += [ctx.ren.isel_matcher.rewrite(UOp(Ops.ADD, sp.dtype, (sp, offset), tag=sp.tag))]
return nx, before + [nx] + after
pm_regalloc_rewrite = PatternMatcher([
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL} | PSEUDO_OPS, name="x"),
regalloc_rewrite),
])