mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
if "NOOPT" not in os.environ: os.environ["NOOPT"] = "1"
|
||||
from tinygrad import Device, nn, Tensor, dtypes
|
||||
from tinygrad import Device, nn, Tensor, dtypes, Variable
|
||||
Device.DEFAULT = "CLANG"
|
||||
from train_gpt2 import GPT, GPTConfig
|
||||
from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GRAPH, GlobalCounters, ansilen, to_function_name
|
||||
@@ -19,13 +19,16 @@ if __name__ == "__main__":
|
||||
#early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)], seen)
|
||||
#print(f"built model {len(early_sched)}")
|
||||
|
||||
#B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64)
|
||||
B, T = 4, 64
|
||||
|
||||
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-4)
|
||||
warmup_count = getenv("WARMUP", 3)
|
||||
for i in range(warmup_count): # TODO: why does it take three and not two to stablize
|
||||
if i == warmup_count-1: GRAPH.value = getenv("LATEGRAPH")
|
||||
GlobalCounters.reset()
|
||||
X = Tensor.empty(4, 64, dtype=dtypes.int)
|
||||
Y = Tensor.empty(4, 64, dtype=dtypes.int)
|
||||
X = Tensor.empty(4, 64, dtype=dtypes.int).reshape(B, T)
|
||||
Y = Tensor.empty(4, 64, dtype=dtypes.int).reshape(B, T)
|
||||
_, loss = model(X, Y)
|
||||
optimizer.zero_grad()
|
||||
if getenv("BACKWARD", 1):
|
||||
@@ -62,10 +65,13 @@ if __name__ == "__main__":
|
||||
state_dict["adam_v_"+nm] = v
|
||||
named_buffers = {v.lazydata.base.buffer:k.replace(".", "_") for k,v in state_dict.items()}
|
||||
|
||||
c_code = [CLANG_PROGRAM_HEADER]
|
||||
c_code = [CLANG_PROGRAM_HEADER, "#include <stdio.h>", "#include <time.h>", "#include <stdlib.h>"]
|
||||
c_code += [x[1] for x in srcs.values()]
|
||||
|
||||
main = ["int main() {"]
|
||||
main += [" struct timespec tm0; clock_gettime(CLOCK_MONOTONIC, &tm0);"]
|
||||
lst = 0
|
||||
|
||||
all_bufs = []
|
||||
for i,si in enumerate(sched):
|
||||
bufs = [(named_buffers.get(b, f"b{numbered_bufs[b]}"), b) for b in si.outputs+si.inputs]
|
||||
@@ -75,13 +81,17 @@ if __name__ == "__main__":
|
||||
else:
|
||||
print(f"{srcs[si.ast][0]}({', '.join([x[0] for x in bufs])})")
|
||||
main.append(f" {to_function_name(srcs[si.ast][0])}({', '.join([x[0] for x in bufs])});")
|
||||
main.append(f" struct timespec tm{i+1}; clock_gettime(CLOCK_MONOTONIC, &tm{i+1});")
|
||||
main.append(f" printf(\"%10.2f ms + %7.2f ms @ {to_function_name(srcs[si.ast][0])}\\n\"," +\
|
||||
f"((tm{i+1}.tv_sec-tm{0}.tv_sec) + (tm{i+1}.tv_nsec-tm{0}.tv_nsec) / 1e9) * 1e3," +\
|
||||
f"((tm{i+1}.tv_sec-tm{lst}.tv_sec) + (tm{i+1}.tv_nsec-tm{lst}.tv_nsec) / 1e9) * 1e3);")
|
||||
lst = i+1
|
||||
#call = f"{srcs[si.ast][0]}({', '.join(bufs)})"
|
||||
#call += " "*(80-ansilen(call))
|
||||
#print(f"{call} // {i+1}")
|
||||
#print(srcs[si.ast][1])
|
||||
main.append("}")
|
||||
|
||||
for n,b in dedup(all_bufs):
|
||||
c_code.append(f"{b.dtype.name} {n}[{b.size}];")
|
||||
mallocs = [f"{b.dtype.name}* {n} = ({b.dtype.name}*)malloc({b.nbytes});" for n,b in dedup(all_bufs)]
|
||||
|
||||
with open("out.c", "w") as f: f.write('\n'.join(c_code+main))
|
||||
with open("out.c", "w") as f: f.write('\n'.join(c_code+main[0:2]+mallocs+main[2:]))
|
||||
|
||||
65
examples/llm.c/ubench/matmul.c
Normal file
65
examples/llm.c/ubench/matmul.c
Normal file
@@ -0,0 +1,65 @@
|
||||
// clang -Ofast -Wno-unused-result -march=native matmul.c
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <time.h>
|
||||
|
||||
float b52[786432];
|
||||
float b49[196608];
|
||||
float h_0_mlp_c_fc_weight[2359296];
|
||||
float h_0_mlp_c_fc_bias[3072];
|
||||
|
||||
void matmul_forward(float* out,
|
||||
float* inp, float* weight, float* bias,
|
||||
int B, int T, int C, int OC) {
|
||||
// most of the running time is spent here and in matmul_backward
|
||||
// OC is short for "output channels"
|
||||
// inp is (B,T,C), weight is (OC, C), bias is (OC)
|
||||
// out will be (B,T,OC)
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int b = 0; b < B; b++) {
|
||||
for (int t = 0; t < T; t++) {
|
||||
float* out_bt = out + b * T * OC + t * OC;
|
||||
float* inp_bt = inp + b * T * C + t * C;
|
||||
for (int o = 0; o < OC; o++) {
|
||||
float val = (bias != NULL) ? bias[o] : 0.0f;
|
||||
float* wrow = weight + o*C;
|
||||
for (int i = 0; i < C; i++) {
|
||||
val += inp_bt[i] * wrow[i];
|
||||
}
|
||||
out_bt[o] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void r_256_3072_768(float* restrict data0, const float* restrict data1, const float* restrict data2, const float* restrict data3) {
|
||||
for (int ridx0 = 0; ridx0 < 256; ridx0++) {
|
||||
for (int ridx1 = 0; ridx1 < 3072; ridx1++) {
|
||||
float acc0 = 0.0f;
|
||||
float val0 = data3[ridx1];
|
||||
for (int ridx2 = 0; ridx2 < 768; ridx2++) {
|
||||
float val1 = data1[(ridx0*768)+ridx2];
|
||||
float val2 = data2[(ridx1*768)+ridx2];
|
||||
acc0 = ((val1*val2)+acc0);
|
||||
}
|
||||
data0[(ridx0*3072)+ridx1] = (acc0+val0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
for (int i = 0; i < 5; i++) {
|
||||
struct timespec t1, t2, t3;
|
||||
clock_gettime(CLOCK_MONOTONIC, &t1);
|
||||
r_256_3072_768(b52, b49, h_0_mlp_c_fc_weight, h_0_mlp_c_fc_bias);
|
||||
clock_gettime(CLOCK_MONOTONIC, &t2);
|
||||
matmul_forward(b52, b49, h_0_mlp_c_fc_weight, h_0_mlp_c_fc_bias, 4, 64, 768, 3072);
|
||||
clock_gettime(CLOCK_MONOTONIC, &t3);
|
||||
double time_gen = (t2.tv_sec - t1.tv_sec) + (t2.tv_nsec - t1.tv_nsec) / 1e9;
|
||||
double time_real = (t3.tv_sec - t2.tv_sec) + (t3.tv_nsec - t2.tv_nsec) / 1e9;
|
||||
printf("%.2f ms gen vs %.2f ms reference\n", time_gen*1e3, time_real*1e3);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user