diff --git a/test/external/external_benchmark_multitensor_allreduce.py b/test/external/external_benchmark_multitensor_allreduce.py index 6867bcaa8f..92fb5ead3c 100644 --- a/test/external/external_benchmark_multitensor_allreduce.py +++ b/test/external/external_benchmark_multitensor_allreduce.py @@ -44,8 +44,8 @@ def main(): else: sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.") - (ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True) - (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False) + (ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus) + (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus) print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s") print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")