计算大型火炬张量的边缘(衍生物)的最快方法
给定一个带有形状的张量(B,C,H,W)
,我想提取空间数据的边缘,即计算 x
, y <
y < /code>
(H,W)
的方向衍生物,并计算幅度 i = sqrt(| x_amplitude | x_amplitude |^2+| y_amplitude |^2)
我的当前实现是如下所示,
row_mat = np.asarray([[0, 0, 0], [1, 0, -1], [0, 0, 0]])
col_mat = row_mat.T
row_mat = row_mat[None, None, :, :] # expand dim to convolve with tensor (batch,channel,width,height)
col_mat = col_mat[None, None, :, :] # expand dim to convolve with tensor (batch,channel,width,height)
def derivative(batch: torch.Tensor) -> torch.Tensor:
"""
uses convolution to perform x and y derivatives
:param batch: input tensor batch
:return: image derivative magnitudes
"""
x_amplitude = ndimage.convolve(batch, row_mat)
y_amplitude = ndimage.convolve(batch, col_mat)
magnitude = np.sqrt(np.abs(x_amplitude) ** 2 + np.abs(y_amplitude) ** 2)
return torch.tensor(magnitude)
我想知道是否有一种更快的方法,因为这种方法实际上是使用衍生物的定义进行的,因此可能存在弊端。
ps 。要测试此问题,您可以使用张量 torch.randn(1000,128,28,28)
,因为这些是我要处理的维度
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
data:image/s3,"s3://crabby-images/d5906/d59060df4059a6cc364216c4d63ceec29ef7fe66" alt="扫码二维码加入Web技术交流群"
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
对于此特定操作,您可能可以通过“手动”来加快速度:
For this specific operation you might be able to speed things up a bit by doing it "manually":