SIFT的PyTorch实现
生活大爆炸

SIFT的PyTorch实现

hualala
2022-03-16 / 0 评论 / 287 阅读 / 正在检测是否收录...
温馨提示:
本文最后更新于2022年03月16日,已超过765天没有更新,若内容或图片失效,请留言反馈。

代码

def getPoolingKernel(kernel_size=25):
    half_size = float(kernel_size) / 2.0
    xc2 = []
    for i in range(kernel_size):
        xc2.append(half_size - abs(float(i) + 0.5 - half_size))
    xc2 = np.array(xc2)
    kernel = np.outer(xc2.T, xc2)
    kernel = kernel / (half_size ** 2)
    return kernel

def get_bin_weight_kernel_size_and_stride(patch_size, num_spatial_bins):
    ks = 2 * int(patch_size / (num_spatial_bins + 1));
    stride = patch_size // num_spatial_bins
    pad = ks // 4
    return ks, stride, pad

class SIFTNet(nn.Module):
    def CircularGaussKernel(self, kernlen=21, circ=True, sigma_type='hesamp'):
        halfSize = float(kernlen) / 2.
        r2 = float(halfSize ** 2)
        if sigma_type == 'hesamp':
            sigma_mul_2 = 0.9 * r2
        elif sigma_type == 'vlfeat':
            sigma_mul_2 = kernlen ** 2
        else:
            raise ValueError('Unknown sigma_type', sigma_type, 'try hesamp or vlfeat')
        disq = 0
        kernel = np.zeros((kernlen, kernlen))
        for y in range(kernlen):
            for x in range(kernlen):
                disq = (y - halfSize + 0.5) ** 2 + (x - halfSize + 0.5) ** 2
                kernel[y, x] = math.exp(-disq / sigma_mul_2)
                if circ and (disq >= r2):
                    kernel[y, x] = 0.
        return kernel

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'num_ang_bins=' + str(self.num_ang_bins) + \
               ', ' + 'num_spatial_bins=' + str(self.num_spatial_bins) + \
               ', ' + 'patch_size=' + str(self.patch_size) + \
               ', ' + 'rootsift=' + str(self.rootsift) + \
               ', ' + 'sigma_type=' + str(self.sigma_type) + \
               ', ' + 'mask_type=' + str(self.mask_type) + \
               ', ' + 'clipval=' + str(self.clipval) + ')'

    def __init__(self,
                 patch_size=65,
                 num_ang_bins=8,
                 num_spatial_bins=4,
                 clipval=0.2,
                 rootsift=False,
                 mask_type='CircularGauss',
                 sigma_type='hesamp'):
        super(SIFTNet, self).__init__()
        self.eps = 1e-10
        self.num_ang_bins = num_ang_bins
        self.num_spatial_bins = num_spatial_bins
        self.clipval = clipval
        self.rootsift = rootsift
        self.mask_type = mask_type
        self.patch_size = patch_size
        self.sigma_type = sigma_type

        if self.mask_type == 'CircularGauss':
            self.gk = torch.from_numpy(
                self.CircularGaussKernel(kernlen=patch_size, circ=True, sigma_type=sigma_type).astype(np.float32))
        elif self.mask_type == 'Gauss':
            self.gk = torch.from_numpy(
                self.CircularGaussKernel(kernlen=patch_size, circ=False, sigma_type=sigma_type).astype(np.float32))
        elif self.mask_type == 'Uniform':
            self.gk = torch.ones(patch_size, patch_size).float() / float(patch_size * patch_size)
        else:
            raise ValueError(self.mask_type, 'is unknown mask type')

        self.bin_weight_kernel_size, self.bin_weight_stride, self.pad = get_bin_weight_kernel_size_and_stride(
            patch_size, num_spatial_bins)
        self.gx = nn.Conv2d(1, 1, kernel_size=(1, 3), bias=False)
        self.gx.weight.data = torch.tensor(np.array([[[[-1, 0, 1]]]], dtype=np.float32))

        self.gy = nn.Conv2d(1, 1, kernel_size=(3, 1), bias=False)
        self.gy.weight.data = torch.from_numpy(np.array([[[[-1], [0], [1]]]], dtype=np.float32))
        nw = getPoolingKernel(kernel_size=self.bin_weight_kernel_size)

        self.pk = nn.Conv2d(1, 1, kernel_size=(nw.shape[0], nw.shape[1]),
                            stride=(self.bin_weight_stride, self.bin_weight_stride),
                            padding=(self.pad, self.pad),
                            bias=False)
        new_weights = np.array(nw.reshape((1, 1, nw.shape[0], nw.shape[1])))
        self.pk.weight.data = torch.from_numpy(new_weights.astype(np.float32))
        return

    def forward(self, x):
        gx = self.gx(F.pad(x, (1, 1, 0, 0), 'replicate'))
        gy = self.gy(F.pad(x, (0, 0, 1, 1), 'replicate'))
        mag = torch.sqrt(gx * gx + gy * gy + self.eps)
        ori = torch.atan2(gy, gx + self.eps)
        mag = mag * self.gk.expand_as(mag).to(mag.device)
        o_big = (ori + 2.0 * math.pi) / (2.0 * math.pi) * float(self.num_ang_bins)
        bo0_big_ = torch.floor(o_big)
        wo1_big_ = o_big - bo0_big_
        bo0_big = bo0_big_ % self.num_ang_bins
        bo1_big = (bo0_big + 1) % self.num_ang_bins
        wo0_big = (1.0 - wo1_big_) * mag
        wo1_big = wo1_big_ * mag
        ang_bins = []
        for i in range(0, self.num_ang_bins):
            out = self.pk((bo0_big == i).float() * wo0_big + (bo1_big == i).float() * wo1_big)
            ang_bins.append(out)
        ang_bins = torch.cat(ang_bins, 1)
        ang_bins = ang_bins.view(ang_bins.size(0), -1)
        ang_bins = F.normalize(ang_bins, p=2)
        ang_bins = torch.clamp(ang_bins, 0., float(self.clipval))
        ang_bins = F.normalize(ang_bins, p=2)
        if self.rootsift:
            ang_bins = torch.sqrt(F.normalize(ang_bins, p=1) + 1e-10)
        return ang_bins

class SIFTNetFeature2D:
    def __init__(self):
        self.model = SIFTNet(patch_size=32, mask_type='Gauss', sigma_type='hesamp')
        device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
        torch.set_grad_enabled(False)
        if torch.cuda.is_available():
            self.model.cuda()
        self.model.eval()

    def patch2des_once(self, patches):
        if len(patches.shape) == 4:
            new_patches = []
            for i in range(len(patches)):
                new_patches.append(cv2.cvtColor(patches[i] * 255, cv2.COLOR_BGR2GRAY) / 255.0)
            new_patches = torch.from_numpy(numpy.array(new_patches).astype(np.float32)).to(device)
            data_a = new_patches.permute(0, 2, 1).unsqueeze(dim=1)
        else:
            data_a = torch.from_numpy(numpy.array(patches).astype(np.float32)).to(device).unsqueeze(dim=1)
        with torch.no_grad():
            out_a = self.model(data_a)
        return out_a

    def compute(self, patches):
        batch_size = 128
        dim = 128
        descriptors_for_net = np.zeros((len(patches), dim), dtype=np.float32)
        for i in range(0, len(patches), batch_size):
            data_a = patches[i: i + batch_size, :, :].astype(np.float32)
            out_a = self.patch2des_once(data_a)
            descriptors_for_net[i: i + batch_size] = out_a.cpu().detach().numpy().reshape(-1, dim)
        return descriptors_for_net / np.max(np.abs(descriptors_for_net))

用法

sift_net = SIFTNetFeature2D()
sift_net.compute(patches)

效果
比OpenCV的经典SIFT效果略差。

其他
当然也可以直接使用kornia写好的,文档网址:
https://kornia.readthedocs.io/en/latest/_modules/kornia/feature/siftdesc.html#SIFTDescriptor

0

评论 (0)

取消