more work parsing SQTT, separate VIZ/PROFILE (#13308)

* more work parsing SQTT

* more minimal runner

* sep VIZ/PROFILE

* parse print new

* improve parser

* more filter

* that

* split them

* lil cleanup

* skip flaky test

* AQL in mmapeak
This commit is contained in:
George Hotz
2025-11-16 10:40:39 -08:00
committed by GitHub
parent 13efdf8c31
commit cabd4add48
9 changed files with 344 additions and 118 deletions

View File

@@ -1,8 +1,10 @@
import pathlib
import os, pathlib
# TODO: there is a timing bug without this
os.environ["AMD_AQL"] = "1"
from tinygrad.device import Device
from tinygrad.runtime.ops_amd import AMDProgram, HIPCompiler
import time
import os
NUM_WORKGROUPS = 96
WAVE_SIZE = 32
@@ -44,9 +46,9 @@ if __name__=="__main__":
raise RuntimeError("Error while initiating AMD device")
COMPILER = HIPCompiler(DEV.arch)
if DEV.arch in {'gfx1100', 'gfx1103'}:
if DEV.arch == 'gfx1103':
NUM_WORKGROUPS = 8
if DEV.arch in {'gfx1100', 'gfx1103', 'gfx1151'}:
if DEV.arch == 'gfx1103': NUM_WORKGROUPS = 8
if DEV.arch == 'gfx1151': NUM_WORKGROUPS = 40
launchBenchmark("v_wmma_bf16_16x16x16_bf16", (7,8,15))
launchBenchmark("v_wmma_f16_16x16x16_f16", (7,8,15))
launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,15))

View File

@@ -0,0 +1,99 @@
import os
os.environ["PYTHONPATH"] = "."
os.environ["SQTT"] = "1"
if "DEV" not in os.environ: os.environ["DEV"] = "AMD"
os.environ["PROFILE"] = "1"
os.environ["AMD_LLVM"] = "0"
from dataclasses import replace
import atexit, contextlib
from tinygrad.helpers import system, getenv
from tinygrad.runtime.ops_amd import AMDProgram
from extra.sqtt.roc import decode, WaveExec, ProfileSQTTEvent
from tinygrad.device import Device, ProfileDeviceEvent
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
def set_power(x): system(f"sudo /opt/rocm/bin/amd-smi set -l {x}")
@atexit.register
def reset_power(): set_power("auto")
set_power("stable_std")
dev = Device["AMD"]
@contextlib.contextmanager
def save_sqtt():
# clear the old traces
dev.profile_events.clear()
sqtt:dict[str, list[WaveExec]] = {}
yield sqtt
events = dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())]
rctx = decode(events)
assert len(rctx.inst_execs) > 0, "empty sqtt output"
sqtt.update(rctx.inst_execs)
for e in events:
if isinstance(e, ProfileSQTTEvent):
print(replace(e, blob=b''))
if e.se == 0:
parse_sqtt_print_packets(e.blob, filter=[0xf, 0x11, 0x12, 0x14] if getenv("FILTER", 1) else None)
template = """.text
.globl matmul
.p2align 8
.type matmul,@function
matmul:
INSTRUCTION
s_endpgm
.rodata
.p2align 6
.amdhsa_kernel matmul
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
.amdhsa_wavefront_size32 1
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.version:
- 1
- 0
amdhsa.kernels:
- .name: matmul
.symbol: matmul.kd
.kernarg_segment_size: 0
.group_segment_fixed_size: 0
.private_segment_fixed_size: 0
.kernarg_segment_align: 4
.wavefront_size: 32
.sgpr_count: 8
.vgpr_count: 32
.max_flat_workgroup_size: 1024
...
.end_amdgpu_metadata
"""
def run_asm(src):
NUM_WORKGROUPS = 1
WAVE_SIZE = 32
NUM_WAVES = 1
lib = dev.compiler.compile(template.replace("INSTRUCTION", '\n'.join(src)))
dev.compiler.disassemble(lib)
fxn = AMDProgram(dev, "matmul", lib)
fxn(global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True)
if __name__ == "__main__":
with save_sqtt() as sqtt:
run_asm([
#"v_rcp_f32 v1, v0"
"v_add_f32_e32 v1 v0 v0",
"v_add_f32_e32 v3 v2 v2",
"v_add_f32_e32 v5 v4 v4",
"v_add_f32_e32 v7 v6 v6",
#"v_add_f32_e32 v1 v0 v0",
#"v_add_f32_e32 v2 v1 v1",
#"s_nop 1"
]*1)

View File

@@ -1,44 +1,48 @@
import pickle
from hexdump import hexdump
from extra.sqtt.roc import decode, ProfileSQTTEvent
from tinygrad.helpers import getenv
# Instruction packets (one per ISA op)
# NOTE: these are bad guesses and may be wrong! feel free to update if you know better
OPCODE_NAMES = {
# Small metadata / structural packets (NOT ISA op kinds)
0x01: "META_SMALL_ID", # 12-bit identifier / slot tag
0x02: "META_FLAG", # 1-byte flag/mode (CF/AF/8F/DF...)
0x03: "META_SUBEVENT_CODE", # 1-byte sub-event/classification code
0x04: "META_BASE_INDEX_TAG", # 12-bit base index/tag (..D, 9D, 10D, 58D...)
# ------------------------------------------------------------------------
# 0x010x06: small “meta + maybe tiny delta” packets
# ------------------------------------------------------------------------
0x01: "META_ID12_TS_SMALL", # 12-bit ID + 3-bit delta field
0x02: "META_FLAG8_TS_SMALL", # 8-bit flag/mode + small delta
0x03: "META_SUBEVENT8_TS_SMALL", # 8-bit subevent/class + small delta
0x04: "META_BASE_INDEX12_TS", # 12-bit base index + small delta
0x05: "META_DESC24_TS_A", # 24-bit descriptor-ish + delta field
0x06: "META_DESC24_TS_B", # second flavour, 24-bit, delta field
# Instruction / timing / timestamp packets
0x0F: "TIME_SHORT_DELTA_PLUS4", # short ts, raw_delta+4
0x11: "TIME_WAVE_STATE", # compact wave timing/stall state record
0x14: "INST_EXEC_RECORD", # per-instruction execution record
0x16: "TIME_LONG_OR_MARKER", # long delta / marker with 6-byte payload
# ------------------------------------------------------------------------
# 0x070x0F: pure timestamp-ish deltas
# ------------------------------------------------------------------------
0x07: "TS_DELTA_S8_W3", # shift=8, width=3 (small delta)
0x08: "EVT_MATCH_SMALL", # event-ish, see fields below
0x09: "PERF_ROUTE_CONFIG", # routing/indirection config
0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2
0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3
0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer)
0x0D: "TS_DELTA_S5_W3_C", # shift=5, width=3
0x0E: "TS_DELTA_S7_W2", # shift=7, width=2
0x0F: "TS_DELTA_SHORT_PLUS4", # short delta; ROCm adds +4 before accumulate
# State / control / perf snapshots
0x09: "CONTROL_CONFIG_32B", # 32-bit control/config word (bursts of FE88..., C488...)
0x15: "PERFCOUNTER_SNAPSHOT", # perf / TT configuration snapshot (8-byte)
# ------------------------------------------------------------------------
# 0x100x19: timestamps, layout headers, events, perf
# ------------------------------------------------------------------------
0x10: "PSEUDO_NEED_MORE_BITS", # not a real packet; decoder refill hint
# Extra descriptors / events / metrics
0x06: "META_DESCRIPTOR_24B", # 24-bit descriptor (seen in complex kernels like GEMM)
0x08: "EVENT_SMALL", # small in-stream event (5-nibble payload)
0x12: "TIME_SECONDARY_METRIC", # 3-byte secondary timing/latency/perf metric
0x18: "EVENT_SMALL_PAYLOAD", # generic small side-band payload (5 nibbles)
0x19: "EVENT_SUMMARY_48B", # rare 6-byte summary/aggregate metric
0x11: "TS_WAVE_STATE_SAMPLE", # wave stall/termination sample (byte at +10)
0x12: "EVT_SECONDARY_METRIC24", # 24-bit secondary timing/perf metric
0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19
# Pseudo / unknown / not yet observed
0x07: "UNK_DELTA", # unknown
0x0A: "UNK_DELTA2", # unknown
0x0B: "UNK_DELTA3", # unknown
0x0C: "UNK_DELTA4", # unknown
0x0D: "UNK_DELTA5", # unknown
0x0E: "UNK_DELTA6", # unknown
0x10: "UNK_PSEUDO", # not seen; pseudo/placeholder
0x17: "UNK_NO_DELTA", # unknown, likely non-timing event
0x14: "INST_EXEC_OR_CFG", # instruction exec record / config write / COR marker
0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot
0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker
0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B
0x18: "PERF_EVENT_SELECT", # packed selector → FUN_0010aba0
0x19: "EVT_SUMMARY_48B", # 6-byte summary/aggregate metric
}
# these tables are from rocprof trace decoder
@@ -113,17 +117,13 @@ DELTA_MAP_DEFAULT = {
def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
"""
Conservative decoding of a few packet types.
Rules:
- We first mask the 64-bit shift register down to the actual packet
width using NIBBLE_BUDGET[opcode & 0x1F], so we never read bits
that aren't really part of the packet.
- Only layouts that are clearly visible from the decompiled C are
decoded, and names are kept generic (cfg_*, idx_*, id_*, etc).
Decode packet payloads conservatively, using:
- NIBBLE_BUDGET[opcode & 0x1F] to mask reg down to true width.
- DELTA_MAP_DEFAULT[opcode] to expose the "primary" field (often delta).
- Per-opcode layouts derived from rocprof's decompiled consumers.
"""
# --- 0. Restrict to the real packet bits for this opcode -------------
nb_bits = NIBBLE_BUDGET[opcode & 0x1F] # this table is in bits
# --- 0. Restrict to real packet bits ---------------------------------
nb_bits = NIBBLE_BUDGET[opcode & 0x1F]
if nb_bits <= 0 or nb_bits >= 64:
pkt = reg & ((1 << 64) - 1)
else:
@@ -131,23 +131,46 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
fields: list[str] = []
# --- 1. Timestamp-ish opcodes ----------------------------------------
shift, width = DELTA_MAP_DEFAULT.get(opcode, (0, 0))
if width:
field_mask = (1 << width) - 1
shaped_field = (pkt >> shift) & field_mask
else:
field_mask = 0
shaped_field = 0
if opcode == 0x0F: # TIME_SHORT_DELTA_PLUS4
# By the time we get here, `delta` is already raw_delta+4.
# =====================================================================
# 1. Timestamp-centric opcodes (actually drive 'time')
# =====================================================================
if opcode == 0x0F: # TS_DELTA_SHORT_PLUS4
# In the caller, delta already has +4 applied.
raw_delta = shaped_field
fields.append(f"raw_delta={raw_delta}")
fields.append(f"ts_short_plus4={delta}")
return ", ".join(fields)
if opcode == 0x11: # TIME_WAVE_STATE (medium/large delta)
shift, width = DELTA_MAP_DEFAULT[opcode]
raw_delta = (pkt >> shift) & ((1 << width) - 1)
coarse = (pkt >> (shift + width)) & 0xFF # next byte above delta
if opcode == 0x11: # TS_WAVE_STATE_SAMPLE
# DELTA_MAP_DEFAULT: shift=7, width=9 -> small delta.
raw_delta = shaped_field
coarse = (pkt >> (shift + width)) & 0xFF # matches byte at +10 in C
fields.append(f"raw_delta={raw_delta}")
if coarse:
fields.append(f"raw_coarse=0x{coarse:02x}")
fields.append(f"coarse_state=0x{coarse:02x}")
# From decomp:
# - when layout<3 and coarse&1, it sets a "has interesting wave" flag
# - when coarse&8, it marks all live waves as "terminated"
if coarse & 0x01:
fields.append("flag_wave_interest=1")
if coarse & 0x08:
fields.append("flag_terminate_all=1")
return ", ".join(fields)
if opcode == 0x16: # TIME_LONG_OR_MARKER
if opcode == 0x16: # TS_DELTA36_OR_MARK
# Bits:
# bit8 -> 0x100
# bit9 -> 0x200
# bits 12..47 -> 36-bit field used as delta or marker
bit8 = bool(pkt & 0x100)
bit9 = bool(pkt & 0x200)
if not bit9:
@@ -159,40 +182,95 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
val36 = (pkt >> 12) & ((1 << 36) - 1)
fields.append(f"mode={mode}")
fields.append(f"val36=0x{val36:x}")
if mode == "delta":
fields.append(f"delta36={delta}")
return ", ".join(fields)
# --- 2. Opcode 0x14: exec/config record ------------------------------
# For 0x07, 0x0A0x0E, we know they drive time (via DELTA_MAP_DEFAULT),
# but we don't see any other fields used in the decomp.
if opcode in (0x07, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E):
if width:
raw_delta = shaped_field
leftover = pkt & ~(field_mask << shift)
fields.append(f"raw_delta={raw_delta}")
if leftover:
fields.append(f"payload=0x{leftover:x}")
return ", ".join(fields)
if opcode == 0x14:
subop = (pkt >> 16) & 0xFFFF # matches (short)(w >> 0x10)
val32 = (pkt >> 32) & 0xFFFFFFFF # matches (uint)(w >> 0x20)
slot = (pkt >> 7) & 0x7 # used as (idx & 4) + (idx & 3)
hi_byte = (pkt >> 8) & 0xFF
# =====================================================================
# 2. Small "meta + tiny delta" packets (0x010x06)
# =====================================================================
if opcode == 0x01: # META_ID12_TS_SMALL
id12 = pkt & 0xFFF
fields.append(f"id12=0x{id12:03x}")
if width:
fields.append(f"field_s{shift}_w{width}={shaped_field}")
return ", ".join(fields)
if opcode == 0x02: # META_FLAG8_TS_SMALL
flag8 = pkt & 0xFF
fields.append(f"flag8=0x{flag8:02x}")
if width:
fields.append(f"field_s{shift}_w{width}={shaped_field}")
return ", ".join(fields)
if opcode == 0x03: # META_SUBEVENT8_TS_SMALL
sub8 = pkt & 0xFF
fields.append(f"subevent8=0x{sub8:02x}")
if width:
fields.append(f"field_s{shift}_w{width}={shaped_field}")
return ", ".join(fields)
if opcode == 0x04: # META_BASE_INDEX12_TS
idx12 = pkt & 0xFFF
fields.append(f"base_index12=0x{idx12:03x}")
if width:
fields.append(f"field_s{shift}_w{width}={shaped_field}")
return ", ".join(fields)
if opcode in (0x05, 0x06): # META_DESC24_TS_A/B
desc24 = pkt & 0xFFFFFF
fields.append(f"desc24=0x{desc24:06x}")
if width:
fields.append(f"field_s{shift}_w{width}={shaped_field}")
return ", ".join(fields)
# =====================================================================
# 3. Opcode 0x14: exec/config record (+ COR marker)
# =====================================================================
if opcode == 0x14: # INST_EXEC_OR_CFG
subop = (pkt >> 16) & 0xFFFF # (short)(w >> 0x10)
val32 = (pkt >> 32) & 0xFFFFFFFF # (uint)(w >> 0x20)
slot = (pkt >> 7) & 0x7 # index in local_168[...] tables
hi_byte = (pkt >> 8) & 0xFF # determines config vs marker
fields.append(f"subop=0x{subop:04x}")
fields.append(f"slot={slot}")
fields.append(f"val32=0x{val32:08x}")
if hi_byte & 0x80:
# "config" flavour, writes into local_168[...] etc.
# Config flavour: writes config words into per-slot state arrays.
fields.append("kind=config")
if subop == 0x000C:
fields.append("cfg_target=local_168[slot].lo")
elif subop == 0x000D:
fields.append("cfg_target=local_168[slot].hi")
else:
# COR marker: subop 0xC342, val32==0x434F5200 ("COR\0")
# COR marker: subop 0xC342, payload "COR\0" → start of a COR region.
if subop == 0xC342:
fields.append("kind=cor_stream")
if val32 == 0x434F5200:
fields.append("cor_magic='COR\\0'")
return ", ".join(fields)
# --- 3. Opcode 0x17: mode/layout header ------------------------------
# =====================================================================
# 4. Opcode 0x17: layout / mode header
# =====================================================================
if opcode == 0x17:
# From case 0x17:
if opcode == 0x17: # LAYOUT_MODE_HEADER
# From decomp (two sites with identical logic):
# layout = (w >> 7) & 0x3f
# mode = (w >> 0xd) & 3
# group = (w >> 0xf) & 7
@@ -213,27 +291,26 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
fields.append(f"sel_b={sel_b}")
if layout == 4:
fields.append(f"layout4_flag={flag4}")
return ", ".join(fields)
# --- 4. Opcode 0x09: state-ish / indirection record ------------------
# =====================================================================
# 5. Opcode 0x09: state / route config record
# =====================================================================
if opcode == 0x09:
# From case 9 on puVar58[1] (here pkt):
#
# uVar41 = (w & 0xffffffff) >> 7; local_520 = uVar41 & 1
# local_4a0 = (w >> 8) & 3;
# local_4a8 = (w >> 10) & (7 or 0xf) (depends on local_494)
# uVar69 = (w >> 0xd) or (w >> 0xf) (depends on local_494)
# local_518 = (w >> 0x19) & 0x7f;
#
# We *dont* know local_494 here, so we just expose the raw slices.
flag7 = (pkt >> 7) & 0x1 # low bit of uVar41
cls2 = (pkt >> 8) & 0x3 # local_4a0
slot4 = (pkt >> 10) & 0xF # superset of 3-bit local_4a8
idx_lo = (pkt >> 13) & 0x1F # matches uVar69&0x1F when layout<4
idx_hi = (pkt >> 15) & 0x1F # matches uVar69&0x1F when layout>=4
id7 = (pkt >> 0x19) & 0x7F # local_518
if opcode == 0x09: # PERF_ROUTE_CONFIG
# From case 9 in multiple consumers:
# flag7 = (w >> 7) & 1 (low bit of uVar41)
# cls2 = (w >> 8) & 3 (class / group)
# slot4 = (w >> 10) & 0xf (slot / group index)
# idx_lo = (w >> 0xd) & 0x1f (low index, layout<4 path)
# idx_hi = (w >> 0xf) & 0x1f (high index, layout>=4 path)
# id7 = (w >> 0x19) & 0x7f (7-bit id)
flag7 = (pkt >> 7) & 0x1
cls2 = (pkt >> 8) & 0x3
slot4 = (pkt >> 10) & 0xF
idx_lo = (pkt >> 13) & 0x1F
idx_hi = (pkt >> 15) & 0x1F
id7 = (pkt >> 0x19) & 0x7F
fields.append(f"flag7={flag7}")
fields.append(f"cls2={cls2}")
@@ -243,18 +320,18 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
fields.append(f"id7=0x{id7:x}")
return ", ".join(fields)
# --- 5. Opcode 0x18: perf/event trigger ------------------------------
# =====================================================================
# 6. Opcode 0x18: perf/event selector (FUN_0010aba0)
# =====================================================================
if opcode == 0x18:
if opcode == 0x18: # PERF_EVENT_SELECT
# From case 0x18:
# - low 3 bits: (w & 7)
# - mid 3 bits: (w >> 3) & 7 or (w >> 4) & 7 (layoutdependent)
# - hi id: (w >> 0xc) & 0xff OR (w >> 0xd) & 0x7f
# - flag bits at 6 / 7
#
# The *real* semantics depend on global local_494 and accumulated
# local_500, so we keep this as a raw view thats still useful for
# debugging, but not layout-dependent.
# low3 = w & 7
# grp3 = (w >> 3) or (w >> 4) & 7 (layout-dependent)
# flags = bits 6 (B6) and 7 (B7)
# hi8 = (w >> 0xc) & 0xff (layout 4 path)
# hi7 = (w >> 0xd) & 0x7f (other layouts)
# idx5 = (w >> 7) or (w >> 8) & 0x1f, used as wave index
low3 = pkt & 0x7
grp3_a = (pkt >> 3) & 0x7
grp3_b = (pkt >> 4) & 0x7
@@ -276,13 +353,34 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
fields.append(f"hi7=0x{hi7:02x}")
return ", ".join(fields)
# --- 6. Generic tiny event-ish packets -------------------------------
# =====================================================================
# 7. Opcode 0x15: perfcounter snapshot
# =====================================================================
if opcode in (0x08, 0x12, 0x19):
# These are all "small event" style tokens. The exact layout depends
# on global state (local_494 etc), so we just show:
# - low 8 bits as a kind/flag byte
# - the rest as an opaque payload.
if opcode == 0x15: # PERFCOUNTER_SNAPSHOT
# NIBBLE_BUDGET gives full 64 bits here.
# DELTA_MAP_DEFAULT: shift=7, width=3 → tiny delta field.
raw_delta = shaped_field if width else 0
# low bits below the delta field
snap_low = pkt & ((1 << shift) - 1) if shift else 0
# everything above delta field
snap_hi = pkt >> (shift + width) if width else (pkt >> shift)
fields.append(f"raw_delta={raw_delta}")
fields.append(f"snap_low_s{shift}=0x{snap_low:x}")
fields.append(f"snap_hi=0x{snap_hi:x}")
return ", ".join(fields)
# =====================================================================
# 8. Small event-ish packets (0x08 / 0x12 / 0x13 / 0x19)
# =====================================================================
if opcode in (0x08, 0x12, 0x13, 0x19):
# These are all "small event / metric" style tokens. The exact semantics
# depend on layout (0x17) and accumulated state (local_500 etc), so we
# expose:
# - low 8 bits as kind byte
# - rest as opaque payload.
kind = pkt & 0xFF
payload = pkt >> 8
fields.append(f"kind_byte=0x{kind:02x}")
@@ -290,10 +388,27 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
fields.append(f"payload=0x{payload:x}")
return ", ".join(fields)
# --- 7. Everything else: no extra decode -----------------------------
return ""
# =====================================================================
# 9. Pseudo opcode 0x10: never a "real" packet
# =====================================================================
def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000) -> None:
if opcode == 0x10: # PSEUDO_NEED_MORE_BITS
# The main loop never prints these; they're just a control token.
return ""
# =====================================================================
# 10. Generic fallback: expose the DELTA_MAP_DEFAULT field + leftover
# =====================================================================
if width:
fields.append(f"field_s{shift}_w{width}={shaped_field}")
leftover = pkt & ~(field_mask << shift)
if leftover:
fields.append(f"payload=0x{leftover:x}")
return ", ".join(fields)
def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None) -> None:
"""
Minimal debug: print ONE LINE per decoded token (packet).
@@ -350,18 +465,22 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000) -> None:
if two_bits == 1:
flags |= 0x01
# Common 36-bit field at bits [12..47]
val36 = (reg >> 12) & ((1 << 36) - 1)
if (reg & 0x200) == 0:
# delta mode: 36-bit delta at bits [12..47]
delta = (reg >> 12) & ((1 << 36) - 1)
# delta mode: add 36-bit delta to time
delta = val36
time += delta
note = "0x16-delta"
else:
# marker mode if bit9==1 and bit8==0
if (reg & 0x100) == 0:
val = (reg >> 12) & ((1 << 36) - 1)
# marker / other modes: no time advance
if (reg & 0x100) == 0 and val36 != 0:
# real marker: bit9=1, bit8=0, non-zero payload
delta = 0
note = f"0x16-marker val=0x{val:x}"
note = f"0x16-marker val=0x{val36:x}"
else:
# "other" 0x16 variants, ignored for timing
delta = 0
note = "0x16-other"
else:
@@ -387,8 +506,7 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000) -> None:
extra = decode_packet_fields(opcode, reg, delta)
if extra: note = (note + " ; " + extra) if note else extra
BORING_OPCODES = {0x11, 0x14}
if opcode not in BORING_OPCODES or getenv("BORING", 1):
if filter is None or opcode not in filter:
my_reg = reg
my_reg &= (1 << nib_budget) - 1
print(

View File

@@ -111,7 +111,6 @@ def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
@rocprof.rocprof_trace_decoder_isa_callback_t
def isa_cb(instr_ptr, mem_size_ptr, size_ptr, pc, data_ptr):
if DEBUG >= 8: print(f"isa_cb {pc.address=} {pc.code_object_id=}")
instr, mem_size_ptr[0] = ROCParseCtx.disasms[(unwrap(ROCParseCtx.active_kern), pc.address)]
# this is the number of bytes to next instruction, set to 0 for end_pgm

View File

@@ -17,7 +17,7 @@ def helper_collect_profile(*devs):
cpu_events.clear()
profile_list = []
with Context(VIZ=1):
with Context(VIZ=1, PROFILE=1):
yield profile_list
for dev in devs: dev.synchronize()
for dev in devs: dev._at_profile_finalize()

View File

@@ -66,6 +66,7 @@ class TestProgressBar(unittest.TestCase):
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tinytqdm_output, tqdm_output)
@unittest.skip("this is flaky")
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_unit_scale(self, mock_terminal_size, mock_stderr):

View File

@@ -6,6 +6,7 @@ from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatch
from tinygrad.uop.symbolic import sym
from tinygrad.dtype import dtypes
from tinygrad.helpers import PROFILE, colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker
from tinygrad.helpers import VIZ
from tinygrad.device import Buffer
@track_rewrites(name=True)
@@ -33,11 +34,14 @@ class BaseTestViz(unittest.TestCase):
cpu_events.clear()
self.tms = TRACK_MATCH_STATS.value
self.profile = PROFILE.value
self.viz = VIZ.value
TRACK_MATCH_STATS.value = 2
PROFILE.value = 1
VIZ.value = 1
def tearDown(self):
TRACK_MATCH_STATS.value = self.tms
PROFILE.value = self.profile
VIZ.value = self.viz
class TestViz(BaseTestViz):
def test_simple(self):

View File

@@ -5,7 +5,7 @@ from typing import Any, Generic, TypeVar, Iterator, Sequence, cast, Generator
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM
from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
from tinygrad.helpers import unwrap_class_type, suppress_finalizing, AMD_LLVM, select_first_inited
from tinygrad.helpers import unwrap_class_type, suppress_finalizing, AMD_LLVM, select_first_inited, VIZ
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
from tinygrad.renderer import Renderer
@@ -355,8 +355,9 @@ if PROFILE:
with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(cpu_events+Compiled.profile_events+Buffer.profile_events, f)
from tinygrad.uop.ops import launch_viz
launch_viz("PROFILE", fn)
if VIZ:
from tinygrad.uop.ops import launch_viz
launch_viz("PROFILE", fn)
def enumerate_devices_str() -> Generator[str, None, None]:
from tinygrad import Tensor, Device

View File

@@ -179,7 +179,9 @@ ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), Conte
EMULATE = ContextVar("EMULATE", "")
CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1)))
CPU_LLVM, CPU_LVP, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0), ContextVar("AMD_LLVM", 0)
VIZ = PROFILE = ContextVar("VIZ", 0)
# VIZ implies PROFILE, but you can run PROFILE without VIZ
VIZ = ContextVar("VIZ", 0)
PROFILE = ContextVar("PROFILE", VIZ.value)
SPEC = ContextVar("SPEC", 1)
# TODO: disable by default due to speed
IGNORE_OOB = ContextVar("IGNORE_OOB", 1)