mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
llama: mllog (#15502)
This commit is contained in:
@@ -1287,6 +1287,9 @@ def train_llama3():
|
||||
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
|
||||
from examples.mlperf.optim import GradAccClipAdamW
|
||||
|
||||
INITMLPERF = getenv("INITMLPERF")
|
||||
RUNMLPERF = getenv("RUNMLPERF")
|
||||
LOGMLPERF = getenv("LOGMLPERF")
|
||||
BENCHMARK = getenv("BENCHMARK")
|
||||
|
||||
config = {}
|
||||
@@ -1309,15 +1312,60 @@ def train_llama3():
|
||||
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 16)
|
||||
EVAL_TARGET = config["EVAL_TARGET"] = getenv("EVAL_TARGET", 5.6)
|
||||
|
||||
# LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. DEV=AMD AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
|
||||
# trains to 7 (NOTE(review): metric not stated here — presumably eval log perplexity, cf. EVAL_TARGET; confirm)
|
||||
if LOGMLPERF:
|
||||
from mlperf_logging import mllog
|
||||
import mlperf_logging.mllog.constants as mllog_constants
|
||||
|
||||
mllog.config(filename=f"result_llama31_{SEED}.log")
|
||||
mllog.config(root_dir=Path(__file__).parents[3].as_posix())
|
||||
MLLOGGER = mllog.get_mllogger()
|
||||
MLLOGGER.logger.propagate = False
|
||||
|
||||
LLAMA_BENCHMARK = mllog_constants.LLAMA31_405B if getenv("LLAMA3_SIZE", "8B") == "405B" else mllog_constants.LLAMA31_8B
|
||||
|
||||
if INITMLPERF:
|
||||
assert BENCHMARK, "BENCHMARK must be set for INITMLPERF"
|
||||
MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
|
||||
MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
|
||||
MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
|
||||
MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
|
||||
|
||||
MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=LLAMA_BENCHMARK)
|
||||
|
||||
diskcache_clear()
|
||||
MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
|
||||
MLLOGGER.start(key=mllog_constants.INIT_START, value=None)
|
||||
|
||||
if RUNMLPERF:
|
||||
MLLOGGER.start(key=mllog_constants.RUN_START, value=None)
|
||||
MLLOGGER.event(key=mllog_constants.SEED, value=SEED)
|
||||
|
||||
MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=GBS)
|
||||
MLLOGGER.event(key=mllog_constants.MAX_SEQUENCE_LENGTH, value=SEQLEN)
|
||||
MLLOGGER.event(key=mllog_constants.MAX_STEPS, value=MAX_STEPS)
|
||||
MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=grad_acc)
|
||||
MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=EVAL_SAMPLES)
|
||||
MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=SAMPLES)
|
||||
|
||||
MLLOGGER.event(key=mllog_constants.OPT_NAME, value=mllog_constants.ADAMW)
|
||||
MLLOGGER.event(key=mllog_constants.OPT_BASE_LR, value=LR)
|
||||
MLLOGGER.event(key=mllog_constants.OPT_END_LR, value=END_LR)
|
||||
MLLOGGER.event(key=mllog_constants.OPT_ADAMW_BETA_1, value=0.9)
|
||||
MLLOGGER.event(key=mllog_constants.OPT_ADAMW_BETA_2, value=0.95)
|
||||
MLLOGGER.event(key=mllog_constants.OPT_ADAMW_EPSILON, value=1e-5)
|
||||
MLLOGGER.event(key=mllog_constants.OPT_ADAMW_WEIGHT_DECAY, value=0.1)
|
||||
MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_STEPS, value=WARMUP_STEPS)
|
||||
MLLOGGER.event(key=mllog_constants.NUM_WARMUP_STEPS, value=WARMUP_STEPS)
|
||||
MLLOGGER.event(key=mllog_constants.OPT_LR_DECAY_STEPS, value=MAX_STEPS - WARMUP_STEPS)
|
||||
MLLOGGER.event(key=mllog_constants.OPT_GRADIENT_CLIP_NORM, value=1.0)
|
||||
else:
|
||||
MLLOGGER = None
|
||||
|
||||
opt_adamw_beta_1 = 0.9
|
||||
opt_adamw_beta_2 = 0.95
|
||||
opt_adamw_epsilon = 1e-5
|
||||
opt_adamw_weight_decay = 0.1
|
||||
|
||||
opt_gradient_clip_norm = 1.0
|
||||
opt_learning_rate_warmup_steps = WARMUP_STEPS
|
||||
opt_learning_rate_decay_steps = MAX_STEPS - opt_learning_rate_warmup_steps
|
||||
opt_base_learning_rate = LR
|
||||
@@ -1451,6 +1499,11 @@ def train_llama3():
|
||||
train_iter = get_train_iter()
|
||||
i, sequences_seen = resume_ckpt, 0
|
||||
step_times = []
|
||||
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.start(key=mllog_constants.EPOCH_START, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
|
||||
MLLOGGER.start(key=mllog_constants.BLOCK_START, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
|
||||
|
||||
while i < MAX_STEPS:
|
||||
GlobalCounters.reset()
|
||||
actual_gbs = GBS if i >= 2 else BS
|
||||
@@ -1533,6 +1586,10 @@ def train_llama3():
|
||||
tqdm.write(f"evaluating after {sequences_seen} sequences")
|
||||
profile_marker(f"eval @ {i}")
|
||||
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.end(key=mllog_constants.BLOCK_STOP, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
|
||||
MLLOGGER.start(key=mllog_constants.EVAL_START, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
|
||||
|
||||
# run eval
|
||||
eval_losses = []
|
||||
eval_iter = get_eval_iter()
|
||||
@@ -1542,22 +1599,34 @@ def train_llama3():
|
||||
eval_losses += eval_step(tokens).tolist()
|
||||
|
||||
if BENCHMARK and (j+1) == min(BENCHMARK, EVAL_SAMPLES//EVAL_BS):
|
||||
if MLLOGGER and INITMLPERF:
|
||||
MLLOGGER.end(key=mllog_constants.INIT_STOP, value=None)
|
||||
return
|
||||
|
||||
log_perplexity = sum(eval_losses) / len(eval_losses)
|
||||
|
||||
tqdm.write(f"eval log perplexity: {log_perplexity:.4f}")
|
||||
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=log_perplexity, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
|
||||
MLLOGGER.end(key=mllog_constants.EVAL_STOP, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
|
||||
|
||||
if WANDB:
|
||||
wandb.log({"eval/log_perplexity": log_perplexity, "eval/sequences_seen": sequences_seen})
|
||||
|
||||
if log_perplexity < EVAL_TARGET:
|
||||
tqdm.write(f"target achieved after {sequences_seen} sequences")
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.end(key=mllog_constants.EPOCH_STOP, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
|
||||
MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=sequences_seen)
|
||||
MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata={mllog_constants.STATUS: mllog_constants.SUCCESS})
|
||||
if getenv("CKPT"):
|
||||
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
|
||||
fn = f"{ckpt_dir}/llama3.safe"
|
||||
safe_save(get_state_dict(model), fn)
|
||||
break
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.start(key=mllog_constants.BLOCK_START, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
|
||||
|
||||
def train_stable_diffusion():
|
||||
from extra.models.unet import UNetModel
|
||||
|
||||
Reference in New Issue
Block a user