From 43c3f73fbcd7d122dd9d3e376ce9faf16031d5ef Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 5 Jul 2024 11:01:20 -0400 Subject: [PATCH] handcode_bert_opt.py (#5295) similar to handcode_resnet50_opt.py, one file to check bert kernels without dataset. --- examples/handcode_bert_opt.py | 98 +++++++++++++++++++++++++++++++++++ examples/mlperf/helpers.py | 6 ++- 2 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 examples/handcode_bert_opt.py diff --git a/examples/handcode_bert_opt.py b/examples/handcode_bert_opt.py new file mode 100644 index 0000000000..a0e0bc4ee5 --- /dev/null +++ b/examples/handcode_bert_opt.py @@ -0,0 +1,98 @@ +from typing import List +from examples.mlperf.helpers import get_mlperf_bert_model +from tinygrad import Tensor, Device, dtypes, nn +from tinygrad.codegen.linearizer import Linearizer +from tinygrad.device import Compiled +from tinygrad.engine.graph import print_tree +from tinygrad.engine.schedule import create_schedule +from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin +from tinygrad.helpers import DEBUG, ansilen, getenv +from tinygrad.ops import LoadOps, get_lazyop_info +from tinygrad.shape.symbolic import sym_infer + + +if __name__ == "__main__": + if getenv("HALF", 1): + dtypes.default_float = dtypes.half + + mdl = get_mlperf_bert_model() + seen = set() + + # the device we are optimizing for + device: Compiled = Device[Device.DEFAULT] + if getenv("BACKWARD"): + Tensor.training = True + optim = (nn.optim.LAMB if getenv("LAMB") else nn.optim.SGD)(nn.state.get_parameters(mdl)) + print(f"optimizing for {Device.DEFAULT}") + + # fake data + BS = getenv("BS", 2) + input_ids = Tensor.empty((BS, 512), dtype=dtypes.float32) + segment_ids = Tensor.empty((BS, 512), dtype=dtypes.float32) + attention_mask = Tensor.empty((BS, 512), dtype=dtypes.default_float) + masked_positions = Tensor.empty((BS, 512), dtype=dtypes.float32) + masked_lm_ids = Tensor.empty((BS, 512), dtype=dtypes.float32) + masked_lm_weights = Tensor.empty((BS, 512), dtype=dtypes.float32) + next_sentence_labels = Tensor.empty((BS, 1), dtype=dtypes.float32) + + # run model twice to get only what changes, these are the kernels of the model + for i in range(2): + lm_logits, seq_relationship_logits = mdl(input_ids, attention_mask, masked_positions, segment_ids) + targets = [lm_logits.lazydata, seq_relationship_logits.lazydata] + if getenv("BACKWARD"): + optim.zero_grad() + loss = mdl.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels) + # ignore grad norm and loss scaler for now + loss.backward() + targets += [x.lazydata for x in optim.schedule_step()] + sched = create_schedule(targets, seen) + print(f"schedule length {len(sched)}") + sched = [x for x in sched if x.ast[0].op not in LoadOps] + + # focus on one kernel + if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1] + + # work with the schedule + total_tm = 0 + running_gflops = 0 + for i,si in enumerate(sched): + ops = sum(get_lazyop_info(ast).flops for ast in si.ast) + + if DEBUG >= 2: + for ast in si.ast: print_tree(ast) + + rawbufs = bufs_from_lin(Linearizer(*si.ast)) + + # "linearize" the op into uops in different ways + lins:List[Linearizer] = [] + + # always try hand coded opt + lin = Linearizer(*si.ast, opts=device.renderer) + lin.hand_coded_optimizations() + lins.append(lin) + + # maybe try tensor cores + lin = Linearizer(*si.ast, opts=device.renderer) + if lin.apply_tensor_cores(): + lins.append(lin) + + # try a beam search + if beam:=getenv("BEAM"): + lin = Linearizer(*si.ast, opts=device.renderer) + lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1))) + lins.append(lin) + + # benchmark the programs + choices = [] + for lin in lins: + tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10) + gflops = sym_infer(ops, {k:k.min for k in lin.ast[0].vars()})*1e-9/tm + choices.append((tm, gflops, lin.linearize())) + + # print all kernels + if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS") + tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0] + total_tm += tm + running_gflops += gflops * tm + print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS") + print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS") diff --git a/examples/mlperf/helpers.py b/examples/mlperf/helpers.py index c8495287be..04588cd804 100644 --- a/examples/mlperf/helpers.py +++ b/examples/mlperf/helpers.py @@ -207,7 +207,7 @@ def get_mlperf_bert_config(): "vocab_size": 30522 } -def get_mlperf_bert_model(checkpoint_path:str): +def get_mlperf_bert_model(checkpoint_path:str=""): from extra.models import bert from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert @@ -219,7 +219,9 @@ def get_mlperf_bert_model(checkpoint_path:str): config = get_mlperf_bert_config() if getenv("DISABLE_DROPOUT", 0): config["hidden_dropout_prob"] = config["attention_probs_dropout_prob"] = 0.0 - return BertForPretraining(**config).load_from_pretrained(checkpoint_path) + model = BertForPretraining(**config) + if checkpoint_path: model.load_from_pretrained(checkpoint_path) + return model def get_data_bert(GPUS:list[str], it): data: dict[str, Tensor] = next(it)