根据字典替换NumPy数组中的值,并避免新值和键之间的重叠


问题内容

我想基于python中的以下字典替换2D numpy数组中的值:

code    region
334     0
4       22
8       31
12      16
16      17
24      27
28      18
32      21
36       1

我想在numpy2D数组中找到匹配code并替换为region列中相应值的单元格。问题在于,这将导致替换code = 12region = 16和在下一行中,所有值为16的单元格(包括刚刚被赋值为16的单元格)都将被替换为值17。如何防止这种情况?


问题答案:

这是一个矢量化的矢量,其依据是np.searchsorted要追溯数组中每个键的位置,然后进行替换,请在这里原谅几乎是
性别歧视的 函数名(尽管无济于事)-

def replace_with_dict(ar, dic):
    # Extract out keys and values
    k = np.array(list(dic.keys()))
    v = np.array(list(dic.values()))

    # Get argsort indices
    sidx = k.argsort()

    # Drop the magic bomb with searchsorted to get the corresponding
    # places for a in keys (using sorter since a is not necessarily sorted).
    # Then trace it back to original order with indexing into sidx
    # Finally index into values for desired output.
    return v[sidx[np.searchsorted(k,ar,sorter=sidx)]]

样品运行-

In [82]: dic ={334:0, 4:22, 8:31, 12:16, 16:17, 24:27, 28:18, 32:21, 36:1}
    ...: 
    ...: np.random.seed(0)
    ...: a = np.random.choice(dic.keys(), 20)
    ...:

In [83]: a
Out[83]: 
array([ 28,  16,  32,  32, 334,  32,  28,   4,   8, 334,  12,  36,  36,
        24,  12, 334, 334,  36,  24,  28])

In [84]: replace_with_dict(a, dic)
Out[84]: 
array([18, 17, 21, 21,  0, 21, 18, 22, 31,  0, 16,  1,  1, 27, 16,  0,  0,
        1, 27, 18])

改善

对于大型数组,一种更快的方法是对值和键数组进行排序,然后searchsorted不使用sorter,就像这样-

def replace_with_dict2(ar, dic):
    # Extract out keys and values
    k = np.array(list(dic.keys()))
    v = np.array(list(dic.values()))

    # Get argsort indices
    sidx = k.argsort()

    ks = k[sidx]
    vs = v[sidx]
    return vs[np.searchsorted(ks,ar)]

运行时测试-

In [91]: dic ={334:0, 4:22, 8:31, 12:16, 16:17, 24:27, 28:18, 32:21, 36:1}
    ...: 
    ...: np.random.seed(0)
    ...: a = np.random.choice(dic.keys(), 20000)

In [92]: out1 = replace_with_dict(a, dic)
    ...: out2 = replace_with_dict2(a, dic)
    ...: print np.allclose(out1, out2)
True

In [93]: %timeit replace_with_dict(a, dic)
1000 loops, best of 3: 453 µs per loop

In [95]: %timeit replace_with_dict2(a, dic)
1000 loops, best of 3: 341 µs per loop

所有数组元素都不在字典中时的一般情况

如果不能保证输入数组中的所有元素都在字典中,则我们需要做一些工作,如下所示-

def replace_with_dict2_generic(ar, dic, assume_all_present=True):
    # Extract out keys and values
    k = np.array(list(dic.keys()))
    v = np.array(list(dic.values()))

    # Get argsort indices
    sidx = k.argsort()

    ks = k[sidx]
    vs = v[sidx]
    idx = np.searchsorted(ks,ar)

    if assume_all_present==0:
        idx[idx==len(vs)] = 0
        mask = ks[idx] == ar
        return np.where(mask, vs[idx], ar)
    else:
        return vs[idx]

样品运行-

In [163]: dic ={334:0, 4:22, 8:31, 12:16, 16:17, 24:27, 28:18, 32:21, 36:1}
     ...: 
     ...: np.random.seed(0)
     ...: a = np.random.choice(dic.keys(), (20))
     ...: a[-1] = 400

In [165]: a
Out[165]: 
array([ 28,  16,  32,  32, 334,  32,  28,   4,   8, 334,  12,  36,  36,
        24,  12, 334, 334,  36,  24, 400])

In [166]: replace_with_dict2_generic(a, dic, assume_all_present=False)
Out[166]: 
array([ 18,  17,  21,  21,   0,  21,  18,  22,  31,   0,  16,   1,   1,
        27,  16,   0,   0,   1,  27, 400])