From dca084f227eb3266bb7ef36dde197cd6cd63e955 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 15 Jun 2023 17:11:12 -0700 Subject: [PATCH] minor == to is touchups --- tinygrad/shape/shapetracker.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 4831e0debb..008b4be1b3 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -103,11 +103,11 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]: this_dim = View(vm2.shape, vm2.strides).expr_node(Variable('idx', 0, s-1)*st) if s == 1: new_strides.append(0) # all shape 1 can have stride 0 - elif this_dim.__class__ == NumNode and this_dim.b == 0: + elif this_dim.__class__ is NumNode and this_dim.b == 0: new_strides.append(0) - elif this_dim.__class__ == Variable: + elif this_dim.__class__ is Variable: new_strides.append(1) - elif this_dim.__class__ == MulNode and cast(MulNode, this_dim).a.__class__ == Variable: + elif this_dim.__class__ is MulNode and cast(MulNode, this_dim).a.__class__ is Variable: new_strides.append(this_dim.b) else: if DEBUG >= 4: print("can't simplify", s, this_dim.render()) @@ -152,7 +152,7 @@ def get_unsafe_resize_offset(strides, arg): class ShapeTracker: __slots__ = "views" def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None): - self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ == ShapeTracker else [view_from_shape(shape)]) + self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [view_from_shape(shape)]) def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})" def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views]) @@ -199,7 +199,7 @@ class ShapeTracker: return self._expr_idx(idx, valid) def expr_node(self, idx='idx'): - if idx.__class__ == str: idx = Variable(idx, 0, prod(self.shape)-1) + if idx.__class__ is str: idx = Variable(idx, 0, prod(self.shape)-1) return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx)) def needs_valid(self) -> bool: @@ -221,7 +221,6 @@ class ShapeTracker: if any([b or e for b, e in arg]): zvarg, mask = get_pad_args(self.shape, arg) self.__unsafe_resize(zvarg, mask=mask) - return self return self def shrink(self, arg: Tuple[Tuple[int, int], ...]):