improve DEBUG=3 [pr] (#9105)

This commit is contained in:
George Hotz
2025-02-15 18:44:56 +08:00
committed by GitHub
parent 41d143d27c
commit 81f5a7af7d
3 changed files with 8 additions and 6 deletions

View File

@@ -514,7 +514,7 @@ class Kernel:
for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
# potentially do more upcasts of non reduce axes based on a heuristic
upcasted_axis = set()
upcasted_axis: set[int] = set()
while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
xb_choices = []
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
@@ -669,8 +669,8 @@ class Kernel:
if DEBUG >= 3:
print(self.name)
if getenv("RAWAST"): print(self.ast)
print(modified_ast)
for i,(buf,st) in enumerate([(buf,st) for buf,st in zip(self.bufs, self.sts) if buf.op not in {Ops.CONST, Ops.VALID}]):
print(f"{i:2d}: {str(st.shape):25s} {str(buf.src[0].dtype).replace('dtypes.',''):20s}", st.real_strides())
print(self.applied_opts)
# verify AST matches the spec after applying opts
if __debug__: type_verify(list(modified_ast.toposort))

View File

@@ -269,11 +269,13 @@ sym = symbolic_flat+PatternMatcher([
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# push all GEPs through ALUs (fix arange stuff)
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
if not isinstance(gep.dtype, PtrDType) else None),
# push some GEPs through WMMAs
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
# CAT can't be rendered. It's a VECTORIZE on vectors, so we expand it to a single VECTORIZE with GEPs (TODO: move this later)
(UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count)))),
(UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
if not isinstance(x.dtype, PtrDType) else None),
# tensor core with a 0 input is acc
(UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
(UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),

View File

@@ -275,7 +275,7 @@ class MockDSPProgram:
os.chmod(dsp_lib.name, 0o0777)
# NOTE: this timing includes a docker launch
proc = subprocess.run(["docker", "run", "--rm", "-i", "-v", f"{os.path.abspath(os.path.dirname(dsp_lib.name))}:/work", "-w", "/work",
"qemu-hexagon", "-c", f"qemu-hexagon {'-strace' if DEBUG >= 3 else ''} /work/"+os.path.basename(dsp_lib.name)],
"qemu-hexagon", "-c", f"qemu-hexagon {'-strace' if DEBUG >= 5 else ''} /work/"+os.path.basename(dsp_lib.name)],
input=b''.join([bytes(x) for x in bufs] + [struct.pack("I", x) for x in vals]), stdout=subprocess.PIPE, check=True)
offset = 4
for x in bufs: