diff --git a/examples/efficientnet.py b/examples/efficientnet.py index cf45261e5d..b6c8f4c37e 100644 --- a/examples/efficientnet.py +++ b/examples/efficientnet.py @@ -13,9 +13,7 @@ np.set_printoptions(suppress=True) from tinygrad.tensor import Tensor from tinygrad.utils import fetch - -# BatchNorm2D and swish -from tinygrad.nn import * +from tinygrad.nn import BatchNorm2D class MBConvBlock: def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio): @@ -47,14 +45,14 @@ class MBConvBlock: def __call__(self, inputs): x = inputs if self._expand_conv: - x = swish(self._bn0(x.conv2d(self._expand_conv))) + x = self._bn0(x.conv2d(self._expand_conv)).swish() x = x.pad2d(padding=self.pad) x = x.conv2d(self._depthwise_conv, stride=self.strides, groups=self._depthwise_conv.shape[0]) - x = swish(self._bn1(x)) + x = self._bn1(x).swish() # has_se x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4]) - x_squeezed = swish(x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1]))) + x_squeezed = x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1])).swish() x_squeezed = x_squeezed.conv2d(self._se_expand).add(self._se_expand_bias.reshape(shape=[1, -1, 1, 1])) x = x.mul(x_squeezed.sigmoid()) @@ -124,11 +122,11 @@ class EfficientNet: def forward(self, x): x = x.pad2d(padding=(0,1,0,1)) - x = swish(self._bn0(x.conv2d(self._conv_stem, stride=2))) + x = self._bn0(x.conv2d(self._conv_stem, stride=2)).swish() for block in self._blocks: #print(x.shape) x = block(x) - x = swish(self._bn1(x.conv2d(self._conv_head))) + x = self._bn1(x.conv2d(self._conv_head)).swish() x = x.avg_pool2d(kernel_size=x.shape[2:4]) x = x.reshape(shape=(-1, x.shape[1])) #x = x.dropout(0.2) diff --git a/tinygrad/nn.py b/tinygrad/nn.py index 1ce179667c..bf0a0cdd6d 100644 --- a/tinygrad/nn.py +++ b/tinygrad/nn.py @@ -1,8 +1,5 @@ from tinygrad.tensor import Tensor -def swish(x): - return x.mul(x.sigmoid()) - class BatchNorm2D: def __init__(self, sz, eps=0.001): self.eps = eps diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index dd2cd3a679..8bd7ab6e21 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -188,6 +188,9 @@ class Tensor: root = Tensor(np.zeros(self.shape, dtype=self.data.dtype)-1, gpu=self.gpu) return self.mul(y.pow(root)) + def swish(self): + return self.mul(self.sigmoid()) + # An instantiation of the Function is the Context class Function: def __init__(self, *tensors):