diff --git a/models/mask_rcnn.py b/models/mask_rcnn.py index d7e815782b..dba94e3029 100644 --- a/models/mask_rcnn.py +++ b/models/mask_rcnn.py @@ -828,7 +828,6 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling bin_size_w = roi_width / pooled_width exact_sampling = sampling_ratio > 0 - roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil() roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil() @@ -907,7 +906,6 @@ class LevelMapper: self.eps = eps def __call__(self, boxlists): - # TODO: remove numpy s = Tensor.sqrt(Tensor.cat(*[boxlist.area() for boxlist in boxlists])) target_lvls = (self.lvl0 + Tensor.log2(s / self.s0 + self.eps)).floor() target_lvls = target_lvls.clip(min_=self.k_min, max_=self.k_max) @@ -954,21 +952,18 @@ class Pooler: return self.poolers[0](x[0], rois) levels = self.map_levels(boxes) - - num_rois = rois.shape[0] - num_channels = x[0].shape[1] - output_size = self.output_size[0] - - result = np.zeros( - (num_rois, num_channels, output_size, output_size), dtype=x[0].dtype.np - ) + results = [] + all_idxs = [] for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)): + # this is fine because no grad will flow through index idx_in_level = [idx for idx, x in enumerate((levels.numpy() == level)) if x != 0] if len(idx_in_level) > 0: rois_per_level = tensor_gather(rois, idx_in_level) - result[idx_in_level] = pooler(per_level_feature, rois_per_level).numpy() + pooler_output = pooler(per_level_feature, rois_per_level) + all_idxs.extend(idx_in_level) + results.append(pooler_output) - return Tensor(result, dtype=x[0].dtype, device=x[0].device) + return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i:idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])]) class FPNPredictor: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 47e48a3ea4..2cc3d64db5 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ 
-515,7 +515,7 @@ class Tensor:
   def abs(self): return self.relu() + (-self).relu()
   def sign(self): return self / (self.abs() + 1e-10)
   def reciprocal(self): return 1.0/self
-  def floor(self): i = self.cast(dtypes.int32).realize(); cond=i > self; return cond * (i - 1) + (1.0 - cond) * i
+  def floor(self): i = self.cast(dtypes.int32); return (i > self).where(i - 1, i)  # step down only where the cast overshot (assumes int32 cast truncates toward zero); exact integers and 0 stay put
   def ceil(self): return -1 * (-1 * self).floor()
 
   # ***** activation functions (unary) *****