From f178d23ff34ff2531879fc5c2b1e3692374cb01d Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 2 Nov 2020 08:25:32 -0800 Subject: [PATCH] gpu relu is good --- test/test_ops.py | 2 ++ tinygrad/opsgpu.py | 8 +++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index cf17732a2a..3d07f1fa17 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -51,6 +51,8 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, gpu=self.gpu) def test_sqrt(self): helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, gpu=self.gpu) + def test_relu(self): + helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, gpu=self.gpu) def test_dot(self): helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5, gpu=self.gpu) diff --git a/tinygrad/opsgpu.py b/tinygrad/opsgpu.py index cf3d63027d..00cda0b62b 100644 --- a/tinygrad/opsgpu.py +++ b/tinygrad/opsgpu.py @@ -180,12 +180,14 @@ register('matmul', Dot, gpu=True) class ReLU(Function): @staticmethod - def forward(ctx, x): - return unary_op(ctx, 'res_g[gid] = min(a_g[gid], (float)0.);', x) + def forward(ctx, input): + ctx.save_for_backward(input) + return unary_op(ctx, 'res_g[gid] = max(a_g[gid], (float)0.);', input) @staticmethod def backward(ctx, grad_output): - return grad_output + input, = ctx.saved_tensors + return binary_op(ctx, 'res_g[gid] = a_g[gid] * (b_g[gid] >= 0);', grad_output, input) register('relu', ReLU, gpu=True) class LogSoftmax(Function):