如何从Python的修补程序中恢复3D图像?
问题内容:
我有一个3D形状的图像DxHxW
。我成功地将图像提取为补丁pdxphxpw
(重叠补丁)。对于每个补丁,我都会做一些处理。现在,我想从已处理的补丁生成图像,以使新图像必须与原始图像具有相同的形状。你能帮我做吗。
这是我提取补丁的代码
def patch_extract_3D(input,patch_shape,xstep=1,ystep=1,zstep=1):
patches_3D = np.lib.stride_tricks.as_strided(input, ((input.shape[0] - patch_shape[0] + 1) / xstep, (input.shape[1] - patch_shape[1] + 1) / ystep,
(input.shape[2] - patch_shape[2] + 1) / zstep, patch_shape[0], patch_shape[1], patch_shape[2]),
(input.strides[0] * xstep, input.strides[1] * ystep,input.strides[2] * zstep, input.strides[0], input.strides[1],input.strides[2]))
patches_3D= patches_3D.reshape(patches_3D.shape[0]*patches_3D.shape[1]*patches_3D.shape[2], patch_shape[0],patch_shape[1],patch_shape[2])
return patches_3D
这是处理补丁程序(只是简单的倍数为2
for i in range(patches_3D.shape[0]):
patches_3D[i]=patches_3D[i];
patches_3D[i]=patches_3D[i]*2;
现在,我需要的是patchs_3D,我想将其重塑为原始图像。谢谢
这是示例代码
patch_shape=[2, 2, 2]
input=np.arange(4*4*6).reshape(4,4,6)
patches_3D=patch_extract_3D(input,patch_shape)
print patches_3D.shape
for i in range(patches_3D.shape[0]):
patches_3D[i]=patches_3D[i]*2
print patches_3D.shape
问题答案:
但是,这样做会相反,因为您的补丁重叠,只有当它们的值在重叠的地方一致时,才能很好地定义
def stuff_patches_3D(out_shape,patches,xstep=12,ystep=12,zstep=12):
out = np.zeros(out_shape, patches.dtype)
patch_shape = patches.shape[-3:]
patches_6D = np.lib.stride_tricks.as_strided(out, ((out.shape[0] - patch_shape[0] + 1) // xstep, (out.shape[1] - patch_shape[1] + 1) // ystep,
(out.shape[2] - patch_shape[2] + 1) // zstep, patch_shape[0], patch_shape[1], patch_shape[2]),
(out.strides[0] * xstep, out.strides[1] * ystep,out.strides[2] * zstep, out.strides[0], out.strides[1],out.strides[2]))
patches_6D[...] = patches.reshape(patches_6D.shape)
return out
更新:这是平均重叠像素的更安全版本:
def stuff_patches_3D(out_shape,patches,xstep=12,ystep=12,zstep=12):
out = np.zeros(out_shape, patches.dtype)
denom = np.zeros(out_shape, patches.dtype)
patch_shape = patches.shape[-3:]
patches_6D = np.lib.stride_tricks.as_strided(out, ((out.shape[0] - patch_shape[0] + 1) // xstep, (out.shape[1] - patch_shape[1] + 1) // ystep,
(out.shape[2] - patch_shape[2] + 1) // zstep, patch_shape[0], patch_shape[1], patch_shape[2]),
(out.strides[0] * xstep, out.strides[1] * ystep,out.strides[2] * zstep, out.strides[0], out.strides[1],out.strides[2]))
denom_6D = np.lib.stride_tricks.as_strided(denom, ((denom.shape[0] - patch_shape[0] + 1) // xstep, (denom.shape[1] - patch_shape[1] + 1) // ystep,
(denom.shape[2] - patch_shape[2] + 1) // zstep, patch_shape[0], patch_shape[1], patch_shape[2]),
(denom.strides[0] * xstep, denom.strides[1] * ystep,denom.strides[2] * zstep, denom.strides[0], denom.strides[1],denom.strides[2]))
np.add.at(patches_6D, tuple(x.ravel() for x in np.indices(patches_6D.shape)), patches.ravel())
np.add.at(denom_6D, tuple(x.ravel() for x in np.indices(patches_6D.shape)), 1)
return out/denom