mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-08 05:54:59 +08:00
2
.github/workflows/benchmark.yml
vendored
2
.github/workflows/benchmark.yml
vendored
@@ -332,7 +332,7 @@ jobs:
|
||||
# - name: Fuzz Padded Tensor Core GEMM (PTX)
|
||||
# run: NV=1 NV_PTX=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
|
||||
- name: HEVC Decode Benchmark
|
||||
run: VALIDATE=1 MAX_FRAMES=100 JITBEAM=1 NV=1 PYTHONPATH=. python3 extra/hevc/decode.py
|
||||
run: VALIDATE=1 MAX_FRAMES=100 ASSERT_FPS=1400 JITBEAM=1 NV=1 PYTHONPATH=. python3 extra/hevc/decode.py
|
||||
- name: Train MNIST
|
||||
run: time PYTHONPATH=. NV=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py
|
||||
- name: Run 10 CIFAR training steps
|
||||
|
||||
@@ -10,9 +10,9 @@ HEVC_ROUNDUP = getenv("DATA_ROUNDUP", 32)
|
||||
@functools.cache
|
||||
def _hevc_jitted_decoder(out_image_size:tuple[int, int], max_hist:int, inplace:bool):
|
||||
def hevc_decode_frame(pos:Variable, hevc_tensor:Tensor, offset:Variable, sz:Variable, opaque:Tensor, i:Variable, *hist:Tensor, outbuf:Tensor|None=None):
|
||||
x = hevc_tensor[offset:offset+sz*HEVC_ROUNDUP].decode_hevc_frame(pos, out_image_size, opaque[i], hist)
|
||||
x = hevc_tensor[offset:offset+sz*HEVC_ROUNDUP].decode_hevc_frame(pos, out_image_size, opaque[i], hist).realize()
|
||||
if outbuf is not None: outbuf.assign(x).realize()
|
||||
return x.realize()
|
||||
return x
|
||||
return TinyJit(hevc_decode_frame)
|
||||
|
||||
def hevc_decode(hevc_tensor:Tensor, opaque:Tensor, frame_info:list, luma_h:int, luma_w:int,
|
||||
@@ -74,10 +74,14 @@ if __name__ == "__main__":
|
||||
Device.default.synchronize()
|
||||
|
||||
# decode all frames using the iterator
|
||||
with Timing("decoding whole file: ", on_exit=(lambda et: f", {len(frame_info)} frames, {len(frame_info)/(et/1e9):.2f} fps")):
|
||||
tm = Timing("decoding whole file: ", on_exit=(lambda et: f", {len(frame_info)} frames, {len(frame_info)/(et/1e9):.2f} fps"))
|
||||
with tm:
|
||||
images = list(hevc_decode(hevc_tensor, opaque_nv, frame_info, luma_h, luma_w, history=hist, preallocated_outputs=out_images))
|
||||
Device.default.synchronize()
|
||||
|
||||
fps = len(frame_info)/(tm.et/1e9)
|
||||
assert fps >= getenv("ASSERT_FPS", 0), f"HEVC decode too slow: {fps:.2f} fps"
|
||||
|
||||
# validation
|
||||
if getenv("VALIDATE", 0):
|
||||
import pickle
|
||||
|
||||
Reference in New Issue
Block a user