diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 87ebd24f4f..fb935ac04b 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -33,6 +33,8 @@ def get_fuzz_rawbufs(lin): data = np.random.randint(-100, 100, size=rawbuf.size, dtype=rawbuf.dtype.np) elif rawbuf.dtype == dtypes.bool: data = np.random.choice([True, False], size=rawbuf.size) + elif rawbuf.dtype == dtypes.half: + data = np.random.uniform(-1, 1, size=rawbuf.size).astype(dtype=rawbuf.dtype.np) else: data = np.random.uniform(-10, 10, size=rawbuf.size).astype(dtype=rawbuf.dtype.np) rawbuf.copyin(Tensor(data).realize().lazydata.realized.as_buffer()) @@ -111,7 +113,7 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut return ("PASS", rawbufs, var_vals, ground_truth,) -def fuzz_linearizer(lin: Linearizer): +def fuzz_linearizer(lin: Linearizer, rtol=1e-2, atol=1e-2): SEED = getenv("SEED", 42) random.seed(SEED) np.random.seed(SEED) @@ -153,7 +155,7 @@ def fuzz_linearizer(lin: Linearizer): if not FUZZ_ALL_ACTIONS: print(test_lin.colored_shape()) - (msg, rawbufs, var_vals, ground_truth) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth) + (msg, rawbufs, var_vals, ground_truth) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol) if msg != "PASS": print(test_lin.ast) print(test_lin.applied_opts) @@ -178,6 +180,8 @@ if __name__ == "__main__": parser.add_argument("--ast", type=str, default=None, help="the ast for the kernel to be optimized") parser.add_argument("--file", type=str, default=None, help="a file containing asts to be optimized, one per line") parser.add_argument("--expected-failures", type=int, default=0, help="the number of expected failed kernels") + parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison") + parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison") args = parser.parse_args() if args.ast is not None: @@ -206,7 +210,7 @@ if __name__ == "__main__": tested += 1 lin = ast_str_to_lin(ast) - fuzz_failures = fuzz_linearizer(lin) + fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol) if fuzz_failures: failed_ids.append(i) for k, v in fuzz_failures.items(): for f in v: