diff --git a/extra/mmapeak/mmapeak.py b/extra/mmapeak/mmapeak.py index 36e24fd372..37b70c8c6e 100644 --- a/extra/mmapeak/mmapeak.py +++ b/extra/mmapeak/mmapeak.py @@ -1,8 +1,10 @@ -import pathlib +import os, pathlib + +# TODO: there is a timing bug without this +os.environ["AMD_AQL"] = "1" + from tinygrad.device import Device from tinygrad.runtime.ops_amd import AMDProgram, HIPCompiler -import time -import os NUM_WORKGROUPS = 96 WAVE_SIZE = 32 @@ -44,9 +46,9 @@ if __name__=="__main__": raise RuntimeError("Error while initiating AMD device") COMPILER = HIPCompiler(DEV.arch) - if DEV.arch in {'gfx1100', 'gfx1103'}: - if DEV.arch == 'gfx1103': - NUM_WORKGROUPS = 8 + if DEV.arch in {'gfx1100', 'gfx1103', 'gfx1151'}: + if DEV.arch == 'gfx1103': NUM_WORKGROUPS = 8 + if DEV.arch == 'gfx1151': NUM_WORKGROUPS = 40 launchBenchmark("v_wmma_bf16_16x16x16_bf16", (7,8,15)) launchBenchmark("v_wmma_f16_16x16x16_f16", (7,8,15)) launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,15)) diff --git a/extra/sqtt/active_sqtt_parse.py b/extra/sqtt/active_sqtt_parse.py new file mode 100644 index 0000000000..0d70be863d --- /dev/null +++ b/extra/sqtt/active_sqtt_parse.py @@ -0,0 +1,99 @@ +import os +os.environ["PYTHONPATH"] = "." +os.environ["SQTT"] = "1" +if "DEV" not in os.environ: os.environ["DEV"] = "AMD" +os.environ["PROFILE"] = "1" +os.environ["AMD_LLVM"] = "0" + +from dataclasses import replace +import atexit, contextlib +from tinygrad.helpers import system, getenv +from tinygrad.runtime.ops_amd import AMDProgram +from extra.sqtt.roc import decode, WaveExec, ProfileSQTTEvent +from tinygrad.device import Device, ProfileDeviceEvent + +from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets + +def set_power(x): system(f"sudo /opt/rocm/bin/amd-smi set -l {x}") +@atexit.register +def reset_power(): set_power("auto") +set_power("stable_std") + +dev = Device["AMD"] + +@contextlib.contextmanager +def save_sqtt(): + # clear the old traces + dev.profile_events.clear() + sqtt:dict[str, list[WaveExec]] = {} + yield sqtt + events = dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())] + + rctx = decode(events) + assert len(rctx.inst_execs) > 0, "empty sqtt output" + sqtt.update(rctx.inst_execs) + + for e in events: + if isinstance(e, ProfileSQTTEvent): + print(replace(e, blob=b'')) + if e.se == 0: + parse_sqtt_print_packets(e.blob, filter=[0xf, 0x11, 0x12, 0x14] if getenv("FILTER", 1) else None) + + +template = """.text +.globl matmul +.p2align 8 +.type matmul,@function +matmul: + INSTRUCTION + s_endpgm + +.rodata +.p2align 6 +.amdhsa_kernel matmul + .amdhsa_next_free_vgpr .amdgcn.next_free_vgpr + .amdhsa_next_free_sgpr .amdgcn.next_free_sgpr + .amdhsa_wavefront_size32 1 +.end_amdhsa_kernel + +.amdgpu_metadata +--- +amdhsa.version: + - 1 + - 0 +amdhsa.kernels: + - .name: matmul + .symbol: matmul.kd + .kernarg_segment_size: 0 + .group_segment_fixed_size: 0 + .private_segment_fixed_size: 0 + .kernarg_segment_align: 4 + .wavefront_size: 32 + .sgpr_count: 8 + .vgpr_count: 32 + .max_flat_workgroup_size: 1024 +... +.end_amdgpu_metadata +""" + +def run_asm(src): + NUM_WORKGROUPS = 1 + WAVE_SIZE = 32 + NUM_WAVES = 1 + lib = dev.compiler.compile(template.replace("INSTRUCTION", '\n'.join(src))) + dev.compiler.disassemble(lib) + fxn = AMDProgram(dev, "matmul", lib) + fxn(global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True) + +if __name__ == "__main__": + with save_sqtt() as sqtt: + run_asm([ + #"v_rcp_f32 v1, v0" + "v_add_f32_e32 v1 v0 v0", + "v_add_f32_e32 v3 v2 v2", + "v_add_f32_e32 v5 v4 v4", + "v_add_f32_e32 v7 v6 v6", + #"v_add_f32_e32 v1 v0 v0", + #"v_add_f32_e32 v2 v1 v1", + #"s_nop 1" + ]*1) diff --git a/extra/sqtt/attempt_sqtt_parse.py b/extra/sqtt/attempt_sqtt_parse.py index 939e1f2e36..996aeed964 100644 --- a/extra/sqtt/attempt_sqtt_parse.py +++ b/extra/sqtt/attempt_sqtt_parse.py @@ -1,44 +1,48 @@ import pickle -from hexdump import hexdump from extra.sqtt.roc import decode, ProfileSQTTEvent -from tinygrad.helpers import getenv # Instruction packets (one per ISA op) # NOTE: these are bad guesses and may be wrong! feel free to update if you know better OPCODE_NAMES = { - # Small metadata / structural packets (NOT ISA op kinds) - 0x01: "META_SMALL_ID", # 12-bit identifier / slot tag - 0x02: "META_FLAG", # 1-byte flag/mode (CF/AF/8F/DF...) - 0x03: "META_SUBEVENT_CODE", # 1-byte sub-event/classification code - 0x04: "META_BASE_INDEX_TAG", # 12-bit base index/tag (..D, 9D, 10D, 58D...) + # ------------------------------------------------------------------------ + # 0x01–0x06: small “meta + maybe tiny delta” packets + # ------------------------------------------------------------------------ + 0x01: "META_ID12_TS_SMALL", # 12-bit ID + 3-bit delta field + 0x02: "META_FLAG8_TS_SMALL", # 8-bit flag/mode + small delta + 0x03: "META_SUBEVENT8_TS_SMALL", # 8-bit subevent/class + small delta + 0x04: "META_BASE_INDEX12_TS", # 12-bit base index + small delta + 0x05: "META_DESC24_TS_A", # 24-bit descriptor-ish + delta field + 0x06: "META_DESC24_TS_B", # second flavour, 24-bit, delta field - # Instruction / timing / timestamp packets - 0x0F: "TIME_SHORT_DELTA_PLUS4", # short ts, raw_delta+4 - 0x11: "TIME_WAVE_STATE", # compact wave timing/stall state record - 0x14: "INST_EXEC_RECORD", # per-instruction execution record - 0x16: "TIME_LONG_OR_MARKER", # long delta / marker with 6-byte payload + # ------------------------------------------------------------------------ + # 0x07–0x0F: pure timestamp-ish deltas + # ------------------------------------------------------------------------ + 0x07: "TS_DELTA_S8_W3", # shift=8, width=3 (small delta) + 0x08: "EVT_MATCH_SMALL", # event-ish, see fields below + 0x09: "PERF_ROUTE_CONFIG", # routing/indirection config + 0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2 + 0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3 + 0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer) + 0x0D: "TS_DELTA_S5_W3_C", # shift=5, width=3 + 0x0E: "TS_DELTA_S7_W2", # shift=7, width=2 + 0x0F: "TS_DELTA_SHORT_PLUS4", # short delta; ROCm adds +4 before accumulate - # State / control / perf snapshots - 0x09: "CONTROL_CONFIG_32B", # 32-bit control/config word (bursts of FE88..., C488...) - 0x15: "PERFCOUNTER_SNAPSHOT", # perf / TT configuration snapshot (8-byte) + # ------------------------------------------------------------------------ + # 0x10–0x19: timestamps, layout headers, events, perf + # ------------------------------------------------------------------------ + 0x10: "PSEUDO_NEED_MORE_BITS", # not a real packet; decoder refill hint - # Extra descriptors / events / metrics - 0x06: "META_DESCRIPTOR_24B", # 24-bit descriptor (seen in complex kernels like GEMM) - 0x08: "EVENT_SMALL", # small in-stream event (5-nibble payload) - 0x12: "TIME_SECONDARY_METRIC", # 3-byte secondary timing/latency/perf metric - 0x18: "EVENT_SMALL_PAYLOAD", # generic small side-band payload (5 nibbles) - 0x19: "EVENT_SUMMARY_48B", # rare 6-byte summary/aggregate metric + 0x11: "TS_WAVE_STATE_SAMPLE", # wave stall/termination sample (byte at +10) + 0x12: "EVT_SECONDARY_METRIC24", # 24-bit secondary timing/perf metric + 0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19 - # Pseudo / unknown / not yet observed - 0x07: "UNK_DELTA", # unknown - 0x0A: "UNK_DELTA2", # unknown - 0x0B: "UNK_DELTA3", # unknown - 0x0C: "UNK_DELTA4", # unknown - 0x0D: "UNK_DELTA5", # unknown - 0x0E: "UNK_DELTA6", # unknown - 0x10: "UNK_PSEUDO", # not seen; pseudo/placeholder - 0x17: "UNK_NO_DELTA", # unknown, likely non-timing event + 0x14: "INST_EXEC_OR_CFG", # instruction exec record / config write / COR marker + 0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot + 0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker + 0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B + 0x18: "PERF_EVENT_SELECT", # packed selector → FUN_0010aba0 + 0x19: "EVT_SUMMARY_48B", # 6-byte summary/aggregate metric } # these tables are from rocprof trace decoder @@ -113,17 +117,13 @@ DELTA_MAP_DEFAULT = { def decode_packet_fields(opcode: int, reg: int, delta: int) -> str: """ - Conservative decoding of a few packet types. - - Rules: - - We first mask the 64-bit shift register down to the actual packet - width using NIBBLE_BUDGET[opcode & 0x1F], so we never read bits - that aren't really part of the packet. - - Only layouts that are clearly visible from the decompiled C are - decoded, and names are kept generic (cfg_*, idx_*, id_*, etc). + Decode packet payloads conservatively, using: + - NIBBLE_BUDGET[opcode & 0x1F] to mask reg down to true width. + - DELTA_MAP_DEFAULT[opcode] to expose the "primary" field (often delta). + - Per-opcode layouts derived from rocprof's decompiled consumers. """ - # --- 0. Restrict to the real packet bits for this opcode ------------- - nb_bits = NIBBLE_BUDGET[opcode & 0x1F] # this table is in bits + # --- 0. Restrict to real packet bits --------------------------------- + nb_bits = NIBBLE_BUDGET[opcode & 0x1F] if nb_bits <= 0 or nb_bits >= 64: pkt = reg & ((1 << 64) - 1) else: @@ -131,23 +131,46 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str: fields: list[str] = [] - # --- 1. Timestamp-ish opcodes ---------------------------------------- + shift, width = DELTA_MAP_DEFAULT.get(opcode, (0, 0)) + if width: + field_mask = (1 << width) - 1 + shaped_field = (pkt >> shift) & field_mask + else: + field_mask = 0 + shaped_field = 0 - if opcode == 0x0F: # TIME_SHORT_DELTA_PLUS4 - # By the time we get here, `delta` is already raw_delta+4. + # ===================================================================== + # 1. Timestamp-centric opcodes (actually drive 'time') + # ===================================================================== + + if opcode == 0x0F: # TS_DELTA_SHORT_PLUS4 + # In the caller, delta already has +4 applied. + raw_delta = shaped_field + fields.append(f"raw_delta={raw_delta}") fields.append(f"ts_short_plus4={delta}") return ", ".join(fields) - if opcode == 0x11: # TIME_WAVE_STATE (medium/large delta) - shift, width = DELTA_MAP_DEFAULT[opcode] - raw_delta = (pkt >> shift) & ((1 << width) - 1) - coarse = (pkt >> (shift + width)) & 0xFF # next byte above delta + if opcode == 0x11: # TS_WAVE_STATE_SAMPLE + # DELTA_MAP_DEFAULT: shift=7, width=9 -> small delta. + raw_delta = shaped_field + coarse = (pkt >> (shift + width)) & 0xFF # matches byte at +10 in C fields.append(f"raw_delta={raw_delta}") if coarse: - fields.append(f"raw_coarse=0x{coarse:02x}") + fields.append(f"coarse_state=0x{coarse:02x}") + # From decomp: + # - when layout<3 and coarse&1, it sets a "has interesting wave" flag + # - when coarse&8, it marks all live waves as "terminated" + if coarse & 0x01: + fields.append("flag_wave_interest=1") + if coarse & 0x08: + fields.append("flag_terminate_all=1") return ", ".join(fields) - if opcode == 0x16: # TIME_LONG_OR_MARKER + if opcode == 0x16: # TS_DELTA36_OR_MARK + # Bits: + # bit8 -> 0x100 + # bit9 -> 0x200 + # bits 12..47 -> 36-bit field used as delta or marker bit8 = bool(pkt & 0x100) bit9 = bool(pkt & 0x200) if not bit9: @@ -159,40 +182,95 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str: val36 = (pkt >> 12) & ((1 << 36) - 1) fields.append(f"mode={mode}") fields.append(f"val36=0x{val36:x}") + if mode == "delta": + fields.append(f"delta36={delta}") return ", ".join(fields) - # --- 2. Opcode 0x14: exec/config record ------------------------------ + # For 0x07, 0x0A–0x0E, we know they drive time (via DELTA_MAP_DEFAULT), + # but we don't see any other fields used in the decomp. + if opcode in (0x07, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E): + if width: + raw_delta = shaped_field + leftover = pkt & ~(field_mask << shift) + fields.append(f"raw_delta={raw_delta}") + if leftover: + fields.append(f"payload=0x{leftover:x}") + return ", ".join(fields) - if opcode == 0x14: - subop = (pkt >> 16) & 0xFFFF # matches (short)(w >> 0x10) - val32 = (pkt >> 32) & 0xFFFFFFFF # matches (uint)(w >> 0x20) - slot = (pkt >> 7) & 0x7 # used as (idx & 4) + (idx & 3) - hi_byte = (pkt >> 8) & 0xFF + # ===================================================================== + # 2. Small "meta + tiny delta" packets (0x01–0x06) + # ===================================================================== + + if opcode == 0x01: # META_ID12_TS_SMALL + id12 = pkt & 0xFFF + fields.append(f"id12=0x{id12:03x}") + if width: + fields.append(f"field_s{shift}_w{width}={shaped_field}") + return ", ".join(fields) + + if opcode == 0x02: # META_FLAG8_TS_SMALL + flag8 = pkt & 0xFF + fields.append(f"flag8=0x{flag8:02x}") + if width: + fields.append(f"field_s{shift}_w{width}={shaped_field}") + return ", ".join(fields) + + if opcode == 0x03: # META_SUBEVENT8_TS_SMALL + sub8 = pkt & 0xFF + fields.append(f"subevent8=0x{sub8:02x}") + if width: + fields.append(f"field_s{shift}_w{width}={shaped_field}") + return ", ".join(fields) + + if opcode == 0x04: # META_BASE_INDEX12_TS + idx12 = pkt & 0xFFF + fields.append(f"base_index12=0x{idx12:03x}") + if width: + fields.append(f"field_s{shift}_w{width}={shaped_field}") + return ", ".join(fields) + + if opcode in (0x05, 0x06): # META_DESC24_TS_A/B + desc24 = pkt & 0xFFFFFF + fields.append(f"desc24=0x{desc24:06x}") + if width: + fields.append(f"field_s{shift}_w{width}={shaped_field}") + return ", ".join(fields) + + # ===================================================================== + # 3. Opcode 0x14: exec/config record (+ COR marker) + # ===================================================================== + + if opcode == 0x14: # INST_EXEC_OR_CFG + subop = (pkt >> 16) & 0xFFFF # (short)(w >> 0x10) + val32 = (pkt >> 32) & 0xFFFFFFFF # (uint)(w >> 0x20) + slot = (pkt >> 7) & 0x7 # index in local_168[...] tables + hi_byte = (pkt >> 8) & 0xFF # determines config vs marker fields.append(f"subop=0x{subop:04x}") fields.append(f"slot={slot}") fields.append(f"val32=0x{val32:08x}") if hi_byte & 0x80: - # "config" flavour, writes into local_168[...] etc. + # Config flavour: writes config words into per-slot state arrays. fields.append("kind=config") if subop == 0x000C: fields.append("cfg_target=local_168[slot].lo") elif subop == 0x000D: fields.append("cfg_target=local_168[slot].hi") else: - # COR marker: subop 0xC342, val32==0x434F5200 ("COR\0") + # COR marker: subop 0xC342, payload "COR\0" → start of a COR region. if subop == 0xC342: fields.append("kind=cor_stream") if val32 == 0x434F5200: fields.append("cor_magic='COR\\0'") - return ", ".join(fields) - # --- 3. Opcode 0x17: mode/layout header ------------------------------ + # ===================================================================== + # 4. Opcode 0x17: layout / mode header + # ===================================================================== - if opcode == 0x17: - # From case 0x17: + if opcode == 0x17: # LAYOUT_MODE_HEADER + # From decomp (two sites with identical logic): # layout = (w >> 7) & 0x3f # mode = (w >> 0xd) & 3 # group = (w >> 0xf) & 7 @@ -213,27 +291,26 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str: fields.append(f"sel_b={sel_b}") if layout == 4: fields.append(f"layout4_flag={flag4}") - return ", ".join(fields) - # --- 4. Opcode 0x09: state-ish / indirection record ------------------ + # ===================================================================== + # 5. Opcode 0x09: state / route config record + # ===================================================================== - if opcode == 0x09: - # From case 9 on puVar58[1] (here pkt): - # - # uVar41 = (w & 0xffffffff) >> 7; local_520 = uVar41 & 1 - # local_4a0 = (w >> 8) & 3; - # local_4a8 = (w >> 10) & (7 or 0xf) (depends on local_494) - # uVar69 = (w >> 0xd) or (w >> 0xf) (depends on local_494) - # local_518 = (w >> 0x19) & 0x7f; - # - # We *don’t* know local_494 here, so we just expose the raw slices. - flag7 = (pkt >> 7) & 0x1 # low bit of uVar41 - cls2 = (pkt >> 8) & 0x3 # local_4a0 - slot4 = (pkt >> 10) & 0xF # superset of 3-bit local_4a8 - idx_lo = (pkt >> 13) & 0x1F # matches uVar69&0x1F when layout<4 - idx_hi = (pkt >> 15) & 0x1F # matches uVar69&0x1F when layout>=4 - id7 = (pkt >> 0x19) & 0x7F # local_518 + if opcode == 0x09: # PERF_ROUTE_CONFIG + # From case 9 in multiple consumers: + # flag7 = (w >> 7) & 1 (low bit of uVar41) + # cls2 = (w >> 8) & 3 (class / group) + # slot4 = (w >> 10) & 0xf (slot / group index) + # idx_lo = (w >> 0xd) & 0x1f (low index, layout<4 path) + # idx_hi = (w >> 0xf) & 0x1f (high index, layout>=4 path) + # id7 = (w >> 0x19) & 0x7f (7-bit id) + flag7 = (pkt >> 7) & 0x1 + cls2 = (pkt >> 8) & 0x3 + slot4 = (pkt >> 10) & 0xF + idx_lo = (pkt >> 13) & 0x1F + idx_hi = (pkt >> 15) & 0x1F + id7 = (pkt >> 0x19) & 0x7F fields.append(f"flag7={flag7}") fields.append(f"cls2={cls2}") @@ -243,18 +320,18 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str: fields.append(f"id7=0x{id7:x}") return ", ".join(fields) - # --- 5. Opcode 0x18: perf/event trigger ------------------------------ + # ===================================================================== + # 6. Opcode 0x18: perf/event selector (FUN_0010aba0) + # ===================================================================== - if opcode == 0x18: + if opcode == 0x18: # PERF_EVENT_SELECT # From case 0x18: - # - low 3 bits: (w & 7) - # - mid 3 bits: (w >> 3) & 7 or (w >> 4) & 7 (layout–dependent) - # - hi id: (w >> 0xc) & 0xff OR (w >> 0xd) & 0x7f - # - flag bits at 6 / 7 - # - # The *real* semantics depend on global local_494 and accumulated - # local_500, so we keep this as a raw view that’s still useful for - # debugging, but not layout-dependent. + # low3 = w & 7 + # grp3 = (w >> 3) or (w >> 4) & 7 (layout-dependent) + # flags = bits 6 (B6) and 7 (B7) + # hi8 = (w >> 0xc) & 0xff (layout 4 path) + # hi7 = (w >> 0xd) & 0x7f (other layouts) + # idx5 = (w >> 7) or (w >> 8) & 0x1f, used as wave index low3 = pkt & 0x7 grp3_a = (pkt >> 3) & 0x7 grp3_b = (pkt >> 4) & 0x7 @@ -276,13 +353,34 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str: fields.append(f"hi7=0x{hi7:02x}") return ", ".join(fields) - # --- 6. Generic tiny event-ish packets ------------------------------- + # ===================================================================== + # 7. Opcode 0x15: perfcounter snapshot + # ===================================================================== - if opcode in (0x08, 0x12, 0x19): - # These are all "small event" style tokens. The exact layout depends - # on global state (local_494 etc), so we just show: - # - low 8 bits as a kind/flag byte - # - the rest as an opaque payload. + if opcode == 0x15: # PERFCOUNTER_SNAPSHOT + # NIBBLE_BUDGET gives full 64 bits here. + # DELTA_MAP_DEFAULT: shift=7, width=3 → tiny delta field. + raw_delta = shaped_field if width else 0 + # low bits below the delta field + snap_low = pkt & ((1 << shift) - 1) if shift else 0 + # everything above delta field + snap_hi = pkt >> (shift + width) if width else (pkt >> shift) + + fields.append(f"raw_delta={raw_delta}") + fields.append(f"snap_low_s{shift}=0x{snap_low:x}") + fields.append(f"snap_hi=0x{snap_hi:x}") + return ", ".join(fields) + + # ===================================================================== + # 8. Small event-ish packets (0x08 / 0x12 / 0x13 / 0x19) + # ===================================================================== + + if opcode in (0x08, 0x12, 0x13, 0x19): + # These are all "small event / metric" style tokens. The exact semantics + # depend on layout (0x17) and accumulated state (local_500 etc), so we + # expose: + # - low 8 bits as kind byte + # - rest as opaque payload. kind = pkt & 0xFF payload = pkt >> 8 fields.append(f"kind_byte=0x{kind:02x}") @@ -290,10 +388,27 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str: fields.append(f"payload=0x{payload:x}") return ", ".join(fields) - # --- 7. Everything else: no extra decode ----------------------------- - return "" + # ===================================================================== + # 9. Pseudo opcode 0x10: never a "real" packet + # ===================================================================== -def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000) -> None: + if opcode == 0x10: # PSEUDO_NEED_MORE_BITS + # The main loop never prints these; they're just a control token. + return "" + + # ===================================================================== + # 10. Generic fallback: expose the DELTA_MAP_DEFAULT field + leftover + # ===================================================================== + + if width: + fields.append(f"field_s{shift}_w{width}={shaped_field}") + leftover = pkt & ~(field_mask << shift) + if leftover: + fields.append(f"payload=0x{leftover:x}") + + return ", ".join(fields) + +def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None) -> None: """ Minimal debug: print ONE LINE per decoded token (packet). @@ -350,18 +465,22 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000) -> None: if two_bits == 1: flags |= 0x01 + # Common 36-bit field at bits [12..47] + val36 = (reg >> 12) & ((1 << 36) - 1) + if (reg & 0x200) == 0: - # delta mode: 36-bit delta at bits [12..47] - delta = (reg >> 12) & ((1 << 36) - 1) + # delta mode: add 36-bit delta to time + delta = val36 time += delta note = "0x16-delta" else: - # marker mode if bit9==1 and bit8==0 - if (reg & 0x100) == 0: - val = (reg >> 12) & ((1 << 36) - 1) + # marker / other modes: no time advance + if (reg & 0x100) == 0 and val36 != 0: + # real marker: bit9=1, bit8=0, non-zero payload delta = 0 - note = f"0x16-marker val=0x{val:x}" + note = f"0x16-marker val=0x{val36:x}" else: + # "other" 0x16 variants, ignored for timing delta = 0 note = "0x16-other" else: @@ -387,8 +506,7 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000) -> None: extra = decode_packet_fields(opcode, reg, delta) if extra: note = (note + " ; " + extra) if note else extra - BORING_OPCODES = {0x11, 0x14} - if opcode not in BORING_OPCODES or getenv("BORING", 1): + if filter is None or opcode not in filter: my_reg = reg my_reg &= (1 << nib_budget) - 1 print( diff --git a/extra/sqtt/roc.py b/extra/sqtt/roc.py index a785b6eaca..9d4317c7b0 100644 --- a/extra/sqtt/roc.py +++ b/extra/sqtt/roc.py @@ -111,7 +111,6 @@ def decode(profile:list[ProfileEvent]) -> _ROCParseCtx: @rocprof.rocprof_trace_decoder_isa_callback_t def isa_cb(instr_ptr, mem_size_ptr, size_ptr, pc, data_ptr): - if DEBUG >= 8: print(f"isa_cb {pc.address=} {pc.code_object_id=}") instr, mem_size_ptr[0] = ROCParseCtx.disasms[(unwrap(ROCParseCtx.active_kern), pc.address)] # this is the number of bytes to next instruction, set to 0 for end_pgm diff --git a/test/test_profiler.py b/test/test_profiler.py index 2836aa4432..c86b2551c7 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -17,7 +17,7 @@ def helper_collect_profile(*devs): cpu_events.clear() profile_list = [] - with Context(VIZ=1): + with Context(VIZ=1, PROFILE=1): yield profile_list for dev in devs: dev.synchronize() for dev in devs: dev._at_profile_finalize() diff --git a/test/unit/test_tqdm.py b/test/unit/test_tqdm.py index 4f9581d595..d3a0b350b5 100644 --- a/test/unit/test_tqdm.py +++ b/test/unit/test_tqdm.py @@ -66,6 +66,7 @@ class TestProgressBar(unittest.TestCase): tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test") self._compare_bars(tinytqdm_output, tqdm_output) + @unittest.skip("this is flaky") @patch('sys.stderr', new_callable=StringIO) @patch('shutil.get_terminal_size') def test_unit_scale(self, mock_terminal_size, mock_stderr): diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index e51f85b1fa..36d69d1a44 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -6,6 +6,7 @@ from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatch from tinygrad.uop.symbolic import sym from tinygrad.dtype import dtypes from tinygrad.helpers import PROFILE, colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker +from tinygrad.helpers import VIZ from tinygrad.device import Buffer @track_rewrites(name=True) @@ -33,11 +34,14 @@ class BaseTestViz(unittest.TestCase): cpu_events.clear() self.tms = TRACK_MATCH_STATS.value self.profile = PROFILE.value + self.viz = VIZ.value TRACK_MATCH_STATS.value = 2 PROFILE.value = 1 + VIZ.value = 1 def tearDown(self): TRACK_MATCH_STATS.value = self.tms PROFILE.value = self.profile + VIZ.value = self.viz class TestViz(BaseTestViz): def test_simple(self): diff --git a/tinygrad/device.py b/tinygrad/device.py index d7e7f2e2bb..b2b8e1fbc4 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -5,7 +5,7 @@ from typing import Any, Generic, TypeVar, Iterator, Sequence, cast, Generator import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup -from tinygrad.helpers import unwrap_class_type, suppress_finalizing, AMD_LLVM, select_first_inited +from tinygrad.helpers import unwrap_class_type, suppress_finalizing, AMD_LLVM, select_first_inited, VIZ from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype from tinygrad.renderer import Renderer @@ -355,8 +355,9 @@ if PROFILE: with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(cpu_events+Compiled.profile_events+Buffer.profile_events, f) - from tinygrad.uop.ops import launch_viz - launch_viz("PROFILE", fn) + if VIZ: + from tinygrad.uop.ops import launch_viz + launch_viz("PROFILE", fn) def enumerate_devices_str() -> Generator[str, None, None]: from tinygrad import Tensor, Device diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 92e3add091..87c67dd47b 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -179,7 +179,9 @@ ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), Conte EMULATE = ContextVar("EMULATE", "") CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1))) CPU_LLVM, CPU_LVP, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0), ContextVar("AMD_LLVM", 0) -VIZ = PROFILE = ContextVar("VIZ", 0) +# VIZ implies PROFILE, but you can run PROFILE without VIZ +VIZ = ContextVar("VIZ", 0) +PROFILE = ContextVar("PROFILE", VIZ.value) SPEC = ContextVar("SPEC", 1) # TODO: disable by default due to speed IGNORE_OOB = ContextVar("IGNORE_OOB", 1)