hevc: assert and speed (#15122)

* hevc: assert and speed

* simpler
This commit is contained in:
nimlgen
2026-03-04 19:01:02 +03:00
committed by GitHub
parent 4e9b85ecfd
commit cdc48da9cd
2 changed files with 8 additions and 4 deletions

View File

@@ -332,7 +332,7 @@ jobs:
# - name: Fuzz Padded Tensor Core GEMM (PTX)
# run: NV=1 NV_PTX=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
- name: HEVC Decode Benchmark
run: VALIDATE=1 MAX_FRAMES=100 JITBEAM=1 NV=1 PYTHONPATH=. python3 extra/hevc/decode.py
run: VALIDATE=1 MAX_FRAMES=100 ASSERT_FPS=1400 JITBEAM=1 NV=1 PYTHONPATH=. python3 extra/hevc/decode.py
- name: Train MNIST
run: time PYTHONPATH=. NV=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py
- name: Run 10 CIFAR training steps

View File

@@ -10,9 +10,9 @@ HEVC_ROUNDUP = getenv("DATA_ROUNDUP", 32)
@functools.cache
def _hevc_jitted_decoder(out_image_size:tuple[int, int], max_hist:int, inplace:bool):
def hevc_decode_frame(pos:Variable, hevc_tensor:Tensor, offset:Variable, sz:Variable, opaque:Tensor, i:Variable, *hist:Tensor, outbuf:Tensor|None=None):
x = hevc_tensor[offset:offset+sz*HEVC_ROUNDUP].decode_hevc_frame(pos, out_image_size, opaque[i], hist)
x = hevc_tensor[offset:offset+sz*HEVC_ROUNDUP].decode_hevc_frame(pos, out_image_size, opaque[i], hist).realize()
if outbuf is not None: outbuf.assign(x).realize()
return x.realize()
return x
return TinyJit(hevc_decode_frame)
def hevc_decode(hevc_tensor:Tensor, opaque:Tensor, frame_info:list, luma_h:int, luma_w:int,
@@ -74,10 +74,14 @@ if __name__ == "__main__":
Device.default.synchronize()
# decode all frames using the iterator
with Timing("decoding whole file: ", on_exit=(lambda et: f", {len(frame_info)} frames, {len(frame_info)/(et/1e9):.2f} fps")):
tm = Timing("decoding whole file: ", on_exit=(lambda et: f", {len(frame_info)} frames, {len(frame_info)/(et/1e9):.2f} fps"))
with tm:
images = list(hevc_decode(hevc_tensor, opaque_nv, frame_info, luma_h, luma_w, history=hist, preallocated_outputs=out_images))
Device.default.synchronize()
fps = len(frame_info)/(tm.et/1e9)
assert fps >= getenv("ASSERT_FPS", 0), f"HEVC decode too slow: {fps:.2f} fps"
# validation
if getenv("VALIDATE", 0):
import pickle