For example, I'm trying to view the implementation of RoI Pooling in PyTorch.
Here is a code fragment showing how to use `RoIPool` in PyTorch:
```python
import torch
from torchvision.ops.roi_pool import RoIPool

device = torch.device('cuda')  # not used below: all tensors stay on the CPU

# create a feature map and random proposal boxes in (x1, y1, x2, y2) format
num_proposals = 10
feature_map = torch.randn(1, 64, 32, 32)
proposals = torch.zeros((num_proposals, 4))
proposals[:, 0] = torch.randint(0, 16, (num_proposals,))
proposals[:, 1] = torch.randint(0, 16, (num_proposals,))
proposals[:, 2] = torch.randint(16, 32, (num_proposals,))
proposals[:, 3] = torch.randint(16, 32, (num_proposals,))

roi_pool_obj = RoIPool(3, 2**-1)  # output_size=3 (3x3 bins), spatial_scale=0.5
roi_pool = roi_pool_obj(feature_map, [proposals])
```
I'm using PyCharm, so when I follow `RoIPool` from the import on the second line, it opens the file `~/anaconda3/envs/CV/lib/python3.8/site-packages/torchvision/ops/roi_pool.py`, which is exactly the same as the code shown in the documentation.
I've pasted that code below with the docstrings removed:
```python
from typing import List, Union

import torch
from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops

from ..utils import _log_api_usage_once
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape


def roi_pool(
    input: Tensor,
    boxes: Union[Tensor, List[Tensor]],
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
) -> Tensor:
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(roi_pool)
    _assert_has_ops()
    check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = _pair(output_size)
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, output_size[0], output_size[1])
    return output


class RoIPool(nn.Module):
    def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float):
        super().__init__()
        _log_api_usage_once(self)
        self.output_size = output_size
        self.spatial_scale = spatial_scale

    def forward(self, input: Tensor, rois: Tensor) -> Tensor:
        return roi_pool(input, rois, self.output_size, self.spatial_scale)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(output_size={self.output_size}, spatial_scale={self.spatial_scale})"
        return s
```
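When `boxes` is a list (one `[L, 4]` box tensor per image, as in `[proposals]` above), `convert_boxes_to_roi_format` concatenates them into a single `[K, 5]` tensor whose first column is the batch index of the image each box belongs to; that `[K, 5]` form is what the native op receives. A minimal pure-Python sketch of the same conversion, using nested lists instead of tensors (the name `boxes_to_rois` is hypothetical, not a torchvision function):

```python
from typing import List, Sequence

Box = Sequence[float]  # (x1, y1, x2, y2)


def boxes_to_rois(boxes_per_image: List[List[Box]]) -> List[List[float]]:
    """Hypothetical helper: flatten per-image boxes into rows of (batch_idx, x1, y1, x2, y2)."""
    rois = []
    for batch_idx, boxes in enumerate(boxes_per_image):
        for x1, y1, x2, y2 in boxes:
            # Prepend the index of the image this box came from.
            rois.append([float(batch_idx), x1, y1, x2, y2])
    return rois


# Two images: one box on image 0, two boxes on image 1.
print(boxes_to_rois([[(0, 0, 8, 8)], [(2, 2, 6, 6), (1, 3, 5, 7)]]))
# [[0.0, 0, 0, 8, 8], [1.0, 2, 2, 6, 6], [1.0, 1, 3, 5, 7]]
```

With a single-image list like `[proposals]`, every row gets batch index 0.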
So, in the code example:
1. Running `roi_pool_obj = RoIPool(3, 2**-1)` creates an instance of `RoIPool` by calling its `__init__` method, which (apart from calling `super().__init__()` and logging API usage) only initializes two instance variables.
2. Running `roi_pool = roi_pool_obj(feature_map, [proposals])` must somehow call the `forward()` method (I don't know how), which then calls the `roi_pool()` function above.
3. Running the `roi_pool()` function does some checking first and then computes the output with the line `output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, output_size[0], output_size[1])`.
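How `roi_pool_obj(feature_map, [proposals])` ends up in `forward()` is ordinary Python: calling an instance invokes the `__call__` method defined on its class. A stripped-down, pure-Python sketch of that pattern (the classes `TinyModule` and `Scale` are hypothetical stand-ins; the real `nn.Module.__call__` also runs hooks and other bookkeeping around `forward`):

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module: calling an instance runs forward()."""

    def __call__(self, *args, **kwargs):
        # The real nn.Module.__call__ also runs hooks around forward();
        # this sketch only delegates to the subclass's forward().
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must implement forward()")


class Scale(TinyModule):
    """Hypothetical layer: like RoIPool, __init__ only stores configuration."""

    def __init__(self, factor: int):
        self.factor = factor

    def forward(self, x):
        return [v * self.factor for v in x]


layer = Scale(2)         # runs Scale.__init__
print(layer([1, 2, 3]))  # layer(...) -> TinyModule.__call__ -> Scale.forward: [2, 4, 6]
```

Python finds `__call__` on the base class, and `self.forward` inside it resolves to the subclass's override. This is also why PyTorch code calls `module(x)` rather than `module.forward(x)`: calling `forward` directly skips the hooks that `nn.Module.__call__` runs.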
But this doesn't show how `roi_pool` is actually implemented, and PyCharm shows "Cannot find declaration to go to" when I try to follow `torch.ops.torchvision.roi_pool`.
To summarize, I have two questions:
- How is `forward()` called when running `roi_pool = roi_pool_obj(feature_map, [proposals])`?
- How can I view the source code of `torch.ops.torchvision.roi_pool`, or where is the file containing its implementation located?
Last but not least, I've just started reading source code, which is pretty difficult for me. I'd appreciate any advice or tutorials as well.
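One habit that helps when reading library code is checking where a callable is actually defined. Python's `inspect` module can locate the source file of anything written in Python, but not of callables implemented in compiled code, a situation similar to PyCharm's "Cannot find declaration" for `torch.ops.torchvision.roi_pool`, whose implementation lives in C++. A stdlib-only sketch, with `json.dumps` (pure Python) and the built-in `len` (C) as stand-ins; the helper name `find_source_file` is hypothetical:

```python
import inspect
import json
from typing import Optional


def find_source_file(obj) -> Optional[str]:
    """Hypothetical helper: path of the .py file that defines obj, or None if there isn't one."""
    try:
        return inspect.getsourcefile(obj)
    except TypeError:
        # Built-ins and other compiled (C/C++) callables have no Python source file.
        return None


print(find_source_file(json.dumps))  # typically a path ending in json/__init__.py
print(find_source_file(len))         # None: len is implemented in C
```

Once a file is found, `inspect.getsource(obj)` returns the source text itself. With torchvision installed, `find_source_file(RoIPool)` should point at the same `roi_pool.py` that PyCharm opened; the C++ kernels are compiled into the installed package, so they have to be read in the torchvision repository instead.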
Answer:
`RoIPool` is a subclass of `torch.nn.Module`. Source code:
- `nn.Module` defines a `__call__` method, which in turn calls the `forward` method. Source code:
- When you execute the statement `roi_pool = roi_pool_obj(feature_map, [proposals])`, the `__call__` method uses the `forward()` of `RoIPool`. Source code:
  `RoIPool.forward` then calls `torch.ops.torchvision.roi_pool`.
- `ops` is an object that loads native libraries implemented in C++:
  https://github.com/pytorch/pytorch/blob/b2311192e6c4745aac3fdd774ac9d56a36b396d4/torch/_ops.py#L537
  so when you access `torch.ops.torchvision`, it uses the operators provided by the `torchvision` native library.
- Here the `roi_pool` function is registered:
- Here you can find the actual implementation of `roi_pool` for the CPU: