Python源码示例:torch.remainder()

示例1
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Rotate every row of ``src_tokens`` so its padding ends up on the other side.

    With ``left_to_right`` each row is rotated left by its number of padding
    tokens (leading padding moves to the end); with ``right_to_left`` it is
    rotated right (trailing padding moves to the front). Exactly one flag
    must be set.

    Args:
        src_tokens: tensor of token ids indexed as ``[row, position]``.
        padding_idx: token id that marks padding.
        right_to_left: move trailing padding to the front.
        left_to_right: move leading padding to the end.

    Returns:
        ``src_tokens`` itself when no rotation is needed, otherwise a new
        tensor gathered from it.
    """
    assert right_to_left ^ left_to_right
    is_pad = src_tokens.eq(padding_idx)

    # Skip the work when there is no padding at all, or when no row has padding
    # on the side it would be moved from (already right / left padded).
    nothing_to_move = (
        not is_pad.any()
        or (left_to_right and not is_pad[:, 0].any())
        or (right_to_left and not is_pad[:, -1].any())
    )
    if nothing_to_move:
        return src_tokens

    seq_len = src_tokens.size(1)
    positions = buffered_arange(seq_len).type_as(src_tokens).expand_as(src_tokens)
    shift = is_pad.long().sum(dim=1, keepdim=True)
    if right_to_left:
        shift = -shift
    # Output column j reads input column (j + shift) mod seq_len.
    return src_tokens.gather(1, torch.remainder(positions + shift, seq_len))
示例2
def convert_padding_direction(
    src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False
):
    """Move the padding of every row of ``src_tokens`` to the opposite side.

    Each row is rotated by its own number of padding tokens. With
    ``left_to_right``, leading padding moves to the end. With
    ``right_to_left``, trailing padding moves to the front. Exactly one of
    the two flags must be set.

    Args:
        src_tokens: 2-D tensor of token ids, indexed as ``[row, position]``.
        padding_idx: token id that marks padding.
        right_to_left: move trailing padding to the front.
        left_to_right: move leading padding to the end.

    Returns:
        ``src_tokens`` itself when there is nothing to move, otherwise a new
        tensor gathered from ``src_tokens``.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # Column positions 0..max_len-1 are created directly with src_tokens'
    # dtype and device. They were previously built on the CPU and then copied
    # over by `type_as`. The name `positions` avoids shadowing builtin `range`.
    positions = torch.arange(
        max_len, dtype=src_tokens.dtype, device=src_tokens.device
    ).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例3
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Rotate every row of ``src_tokens`` so its padding ends up on the other side.

    With ``left_to_right`` each row is rotated left by its number of padding
    tokens (leading padding moves to the end); with ``right_to_left`` it is
    rotated right (trailing padding moves to the front). Exactly one flag
    must be set.

    Args:
        src_tokens: tensor of token ids indexed as ``[row, position]``.
        padding_idx: token id that marks padding.
        right_to_left: move trailing padding to the front.
        left_to_right: move leading padding to the end.

    Returns:
        ``src_tokens`` itself when no rotation is needed, otherwise a new
        tensor gathered from it.
    """
    assert right_to_left ^ left_to_right
    is_pad = src_tokens.eq(padding_idx)

    # Skip the work when there is no padding at all, or when no row has padding
    # on the side it would be moved from (already right / left padded).
    nothing_to_move = (
        not is_pad.any()
        or (left_to_right and not is_pad[:, 0].any())
        or (right_to_left and not is_pad[:, -1].any())
    )
    if nothing_to_move:
        return src_tokens

    seq_len = src_tokens.size(1)
    positions = buffered_arange(seq_len).type_as(src_tokens).expand_as(src_tokens)
    shift = is_pad.long().sum(dim=1, keepdim=True)
    if right_to_left:
        shift = -shift
    # Output column j reads input column (j + shift) mod seq_len.
    return src_tokens.gather(1, torch.remainder(positions + shift, seq_len))
示例4
def mod_divide(self, forget_radix, mask):
        """Pop the low ``forget_radix`` bits off the buffer (reverse process only).

        Where ``mask`` is non-zero, ``self.curr_buffer`` is replaced by its
        quotient by ``2**forget_radix``; elsewhere it is kept. ``self.counter``
        is decremented and, if the overflow flag for the new counter value is
        set, the most recent entry of ``self.past_buffers`` is restored.

        Args:
            forget_radix: number of bits to pop; the divisor is ``2**forget_radix``.
            mask: tensor selecting which buffer entries are divided (cast to long).

        Returns:
            ``torch.remainder(self.curr_buffer, 2**forget_radix)`` of the buffer
            *before* division, as an int32 tensor.
        """
        mask = mask.long()
        self.counter -= 1

        divisor = 2 ** forget_radix
        low_bits = torch.remainder(self.curr_buffer, divisor)
        buf_mod = low_bits.int()
        # (buffer - low_bits) is an exact multiple of divisor, so this division
        # is exact under every rounding mode and keeps an integer buffer
        # integer. The previous plain `/` is true division on integer tensors
        # in current PyTorch and turned the buffer into floats.
        quotient = (self.curr_buffer - low_bits) // divisor
        self.curr_buffer = mask * quotient + (1 - mask) * self.curr_buffer

        # Overflow flags: bit (counter % 8) of element (counter // 8).
        overflowed = self.overflow_detect.__and__(2 ** (self.counter % 8))[self.counter // 8]
        if overflowed:
            # Restore the last saved buffer and drop it from the history,
            # always keeping at least one entry.
            self.curr_buffer = self.past_buffers[:, :, -1]
            if self.past_buffers.size(2) > 1:
                self.past_buffers = self.past_buffers[:, :, :-1]

        return buf_mod

###############################################################################
# Multiply/divide fixed point numbers with buffer
############################################################################### 
示例5
def forward(ctx, h, z, buf, mask, slice_dim=0):
        """Multiply ``h`` by ``z / 2**forget_radix`` exactly, using ``buf`` for lost bits.

        Autograd-style ``forward(ctx, ...)``: ``h``, ``z`` and ``mask`` are saved
        with ``ctx.save_for_backward``.

        Args:
            ctx: autograd context.
            h: integer tensor being updated (bitwise ops are applied to it).
            z: multiplier tensor, indexed like ``h``.
            buf: buffer object providing ``overflow_mul``, ``add``, ``mod`` and
                ``div``, or ``None`` to skip all buffer bookkeeping.
            mask: tensor combined as ``mask*new + (1-mask)*old``; presumably
                0/1 so that masked-out entries keep their old value — TODO confirm.
            slice_dim: only columns ``slice_dim:`` of dim 1 interact with ``buf``.

        Returns:
            The updated ``h`` (a new tensor; the input ``h`` is not modified).

        NOTE(review): ``forget_radix``, ``sign_bit`` and ``negative_bits`` are
        module-level names not visible in this snippet.
        """
        ctx.save_for_backward(h, z, mask)

        # Shift buffer left, enlarging if needed, then store modulus of h in buffer.
        # (h mod 2**forget_radix are the low bits the right shift below discards.)
        if buf is not None:
            h_mod = torch.remainder(h[:, slice_dim:], 2**forget_radix)
            buf.overflow_mul(2**forget_radix, mask[:, slice_dim:])
            buf.add(h_mod, mask[:, slice_dim:])

        # Multiply h by z/(2**forget_radix).
        # Have to do extra work in case h is negative.
        # Presumably sign_bit isolates the sign bit, so clamp(min=-1) maps
        # "sign bit set" to -1 and "clear" to 0, and one_bits is then either
        # negative_bits or 0 — TODO confirm against the sign_bit/negative_bits definitions.
        sign_bits = h.__and__(sign_bit)  
        one_bits = negative_bits * -1 * torch.clamp(sign_bits, min=-1)
        # Right-shift by forget_radix only where mask is set (shift of 0 elsewhere).
        h = h.__rshift__(forget_radix * mask) 
        # OR the sign-fill bits back in for masked entries.
        h = h.__or__(one_bits * mask)
        h = mask*h*z + (1-mask)*h

        # Store modulus of buffer in h then divide buffer by z.
        if buf is not None:
            buf_mod = buf.mod(z[:,slice_dim:])
            h[:,slice_dim:] = h[:,slice_dim:] + buf_mod*mask[:, slice_dim:]
            buf.div(z[:,slice_dim:], mask[:, slice_dim:])

        return h 
示例6
def forward(self, input, offsets=None, per_sample_weights=None):
        """Look up quotient and remainder embedding bags and combine them.

        Every index ``i`` in ``input`` is split into a quotient
        ``i // self.num_collisions``, which is a row of ``self.weight_q``, and a
        remainder ``i mod self.num_collisions``, which is a row of
        ``self.weight_r``. Both are looked up with ``F.embedding_bag`` using the
        stored options and then combined.

        Args:
            input: tensor of category indices.
            offsets: optional bag offsets, passed to ``F.embedding_bag``.
            per_sample_weights: optional per-index weights, passed to
                ``F.embedding_bag``.

        Returns:
            ``torch.cat((embed_q, embed_r), dim=1)`` for ``'concat'``,
            ``embed_q + embed_r`` for ``'add'``, ``embed_q * embed_r`` for
            ``'mult'``.

        Raises:
            ValueError: if ``self.operation`` is none of the above.
        """
        input_r = torch.remainder(input, self.num_collisions)
        # Exact integer quotient: (input - input_r) is a multiple of
        # num_collisions. The old `(input / n).long()` used true division,
        # which rounds through float32 and selects wrong rows once indices
        # exceed 2**24.
        input_q = ((input - input_r) // self.num_collisions).long()
        input_r = input_r.long()

        embed_q = F.embedding_bag(input_q, self.weight_q, offsets, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq, self.mode,
                                  self.sparse, per_sample_weights)
        embed_r = F.embedding_bag(input_r, self.weight_r, offsets, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq, self.mode,
                                  self.sparse, per_sample_weights)

        if self.operation == 'concat':
            return torch.cat((embed_q, embed_r), dim=1)
        if self.operation == 'add':
            return embed_q + embed_r
        if self.operation == 'mult':
            return embed_q * embed_r
        # Previously fell through to an UnboundLocalError on `embed`.
        raise ValueError(f"Not valid operation: {self.operation!r}")
示例7
def forward(self, input, offsets=None, per_sample_weights=None):
        """Look up quotient and remainder embedding bags and combine them.

        Every index ``i`` in ``input`` is split into a quotient
        ``i // self.num_collisions``, which is a row of ``self.weight_q``, and a
        remainder ``i mod self.num_collisions``, which is a row of
        ``self.weight_r``. Both are looked up with ``F.embedding_bag`` using the
        stored options and then combined.

        Args:
            input: tensor of category indices.
            offsets: optional bag offsets, passed to ``F.embedding_bag``.
            per_sample_weights: optional per-index weights, passed to
                ``F.embedding_bag``.

        Returns:
            ``torch.cat((embed_q, embed_r), dim=1)`` for ``'concat'``,
            ``embed_q + embed_r`` for ``'add'``, ``embed_q * embed_r`` for
            ``'mult'``.

        Raises:
            ValueError: if ``self.operation`` is none of the above.
        """
        input_r = torch.remainder(input, self.num_collisions)
        # Exact integer quotient: (input - input_r) is a multiple of
        # num_collisions. The old `(input / n).long()` used true division,
        # which rounds through float32 and selects wrong rows once indices
        # exceed 2**24.
        input_q = ((input - input_r) // self.num_collisions).long()
        input_r = input_r.long()

        embed_q = F.embedding_bag(input_q, self.weight_q, offsets, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq, self.mode,
                                  self.sparse, per_sample_weights)
        embed_r = F.embedding_bag(input_r, self.weight_r, offsets, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq, self.mode,
                                  self.sparse, per_sample_weights)

        if self.operation == 'concat':
            return torch.cat((embed_q, embed_r), dim=1)
        if self.operation == 'add':
            return embed_q + embed_r
        if self.operation == 'mult':
            return embed_q * embed_r
        # Previously fell through to an UnboundLocalError on `embed`.
        raise ValueError(f"Not valid operation: {self.operation!r}")
示例8
def convert_padding_direction(src_tokens,
                              padding_idx,
                              right_to_left=False,
                              left_to_right=False):
    """Rotate every row of ``src_tokens`` so its padding ends up on the other side.

    With ``left_to_right`` each row is rotated left by its number of padding
    tokens (leading padding moves to the end); with ``right_to_left`` it is
    rotated right (trailing padding moves to the front). Exactly one flag
    must be set.

    Args:
        src_tokens: tensor of token ids indexed as ``[row, position]``.
        padding_idx: token id that marks padding.
        right_to_left: move trailing padding to the front.
        left_to_right: move leading padding to the end.

    Returns:
        ``src_tokens`` itself when no rotation is needed, otherwise a new
        tensor gathered from it.
    """
    assert right_to_left ^ left_to_right
    is_pad = src_tokens.eq(padding_idx)

    # Skip the work when there is no padding at all, or when no row has padding
    # on the side it would be moved from (already right / left padded).
    nothing_to_move = (
        not is_pad.any()
        or (left_to_right and not is_pad[:, 0].any())
        or (right_to_left and not is_pad[:, -1].any())
    )
    if nothing_to_move:
        return src_tokens

    seq_len = src_tokens.size(1)
    positions = buffered_arange(seq_len).type_as(src_tokens).expand_as(src_tokens)
    shift = is_pad.long().sum(dim=1, keepdim=True)
    if right_to_left:
        shift = -shift
    # Output column j reads input column (j + shift) mod seq_len.
    return src_tokens.gather(1, torch.remainder(positions + shift, seq_len))
示例9
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Rotate every row of ``src_tokens`` so its padding ends up on the other side.

    With ``left_to_right`` each row is rotated left by its number of padding
    tokens (leading padding moves to the end); with ``right_to_left`` it is
    rotated right (trailing padding moves to the front). Exactly one flag
    must be set.

    Args:
        src_tokens: tensor of token ids indexed as ``[row, position]``.
        padding_idx: token id that marks padding.
        right_to_left: move trailing padding to the front.
        left_to_right: move leading padding to the end.

    Returns:
        ``src_tokens`` itself when no rotation is needed, otherwise a new
        tensor gathered from it.
    """
    assert right_to_left ^ left_to_right
    is_pad = src_tokens.eq(padding_idx)

    # Skip the work when there is no padding at all, or when no row has padding
    # on the side it would be moved from (already right / left padded).
    nothing_to_move = (
        not is_pad.any()
        or (left_to_right and not is_pad[:, 0].any())
        or (right_to_left and not is_pad[:, -1].any())
    )
    if nothing_to_move:
        return src_tokens

    seq_len = src_tokens.size(1)
    positions = buffered_arange(seq_len).type_as(src_tokens).expand_as(src_tokens)
    shift = is_pad.long().sum(dim=1, keepdim=True)
    if right_to_left:
        shift = -shift
    # Output column j reads input column (j + shift) mod seq_len.
    return src_tokens.gather(1, torch.remainder(positions + shift, seq_len))
示例10
def convert_padding_direction(
    src_tokens,
    src_lengths,
    padding_idx,
    right_to_left=False,
    left_to_right=False,
):
    """Move the padding of every row of ``src_tokens`` to the opposite side.

    Each row is rotated by its own number of padding tokens. With
    ``left_to_right``, leading padding moves to the end. With
    ``right_to_left``, trailing padding moves to the front. Exactly one of
    the two flags must be set.

    Args:
        src_tokens: 2-D tensor of token ids, indexed as ``[row, position]``.
        src_lengths: accepted for interface compatibility; not used.
        padding_idx: token id that marks padding.
        right_to_left: move trailing padding to the front.
        left_to_right: move leading padding to the end.

    Returns:
        ``src_tokens`` itself when there is nothing to move, otherwise a new
        tensor gathered from ``src_tokens``.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    # `.any()` instead of `pad_mask.max() == 0`: max() raises on an empty batch.
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # `positions` rather than `range`, so the builtin is not shadowed.
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例11
def convert_padding_direction(
    src_tokens,
    src_lengths,
    padding_idx,
    right_to_left=False,
    left_to_right=False,
):
    """Move the padding of every row of ``src_tokens`` to the opposite side.

    Each row is rotated by its own number of padding tokens. With
    ``left_to_right``, leading padding moves to the end. With
    ``right_to_left``, trailing padding moves to the front. Exactly one of
    the two flags must be set.

    Args:
        src_tokens: 2-D tensor of token ids, indexed as ``[row, position]``.
        src_lengths: accepted for interface compatibility; not used.
        padding_idx: token id that marks padding.
        right_to_left: move trailing padding to the front.
        left_to_right: move leading padding to the end.

    Returns:
        ``src_tokens`` itself when there is nothing to move, otherwise a new
        tensor gathered from ``src_tokens``.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    # `.any()` instead of `pad_mask.max() == 0`: max() raises on an empty batch.
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # `positions` rather than `range`, so the builtin is not shadowed.
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例12
def convert_padding_direction(
    src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False
):
    """Move the padding of every row of ``src_tokens`` to the opposite side.

    Each row is rotated by its own number of padding tokens. With
    ``left_to_right``, leading padding moves to the end. With
    ``right_to_left``, trailing padding moves to the front. Exactly one of
    the two flags must be set.

    Args:
        src_tokens: 2-D tensor of token ids, indexed as ``[row, position]``.
        padding_idx: token id that marks padding.
        right_to_left: move trailing padding to the front.
        left_to_right: move leading padding to the end.

    Returns:
        ``src_tokens`` itself when there is nothing to move, otherwise a new
        tensor gathered from ``src_tokens``.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # Column positions 0..max_len-1 are created directly with src_tokens'
    # dtype and device. They were previously built on the CPU and then copied
    # over by `type_as`. The name `positions` avoids shadowing builtin `range`.
    positions = torch.arange(
        max_len, dtype=src_tokens.dtype, device=src_tokens.device
    ).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例13
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Rotate every row of ``src_tokens`` so its padding ends up on the other side.

    With ``left_to_right`` each row is rotated left by its number of padding
    tokens (leading padding moves to the end); with ``right_to_left`` it is
    rotated right (trailing padding moves to the front). Exactly one flag
    must be set.

    Args:
        src_tokens: tensor of token ids indexed as ``[row, position]``.
        padding_idx: token id that marks padding.
        right_to_left: move trailing padding to the front.
        left_to_right: move leading padding to the end.

    Returns:
        ``src_tokens`` itself when no rotation is needed, otherwise a new
        tensor gathered from it.
    """
    assert right_to_left ^ left_to_right
    is_pad = src_tokens.eq(padding_idx)

    # Skip the work when there is no padding at all, or when no row has padding
    # on the side it would be moved from (already right / left padded).
    nothing_to_move = (
        not is_pad.any()
        or (left_to_right and not is_pad[:, 0].any())
        or (right_to_left and not is_pad[:, -1].any())
    )
    if nothing_to_move:
        return src_tokens

    seq_len = src_tokens.size(1)
    positions = buffered_arange(seq_len).type_as(src_tokens).expand_as(src_tokens)
    shift = is_pad.long().sum(dim=1, keepdim=True)
    if right_to_left:
        shift = -shift
    # Output column j reads input column (j + shift) mod seq_len.
    return src_tokens.gather(1, torch.remainder(positions + shift, seq_len))
示例14
def loss(self, outputs, labels, **_):
        """Apply ``self.criterion`` after suppressing the "flipped" class logit.

        In training mode the flipped class of each sample is
        ``(label + num_classes // 2) % num_classes``. Its logit is replaced by 0,
        every other logit is reduced by 1, and the result is passed to
        ``self.criterion(outputs, labels)``. Outside training mode the method
        returns ``torch.FloatTensor([0])``.

        Args:
            outputs: logits of shape ``(batch, num_classes)``; dim 1 is scattered into.
            labels: integer class labels, 1-D or already ``(batch, 1)``.

        Returns:
            The criterion value in training mode, otherwise a one-element zero tensor.
        """
        if not self.model.training:
            return torch.FloatTensor([0])

        labels_flip = torch.remainder(labels + self.num_classes // 2, self.num_classes)
        if labels_flip.dim() == 1:
            labels_flip = labels_flip.unsqueeze(-1)
        # One-hot on outputs' own device/dtype. The previous
        # `torch.zeros(...).cuda()` failed on CPU and ignored outputs' placement.
        onehot = torch.zeros_like(outputs)
        onehot.scatter_(1, labels_flip, 1)
        onehot_invert = (onehot == 0).float()
        outputs = outputs * onehot_invert - onehot_invert
        return self.criterion(outputs, labels)
示例15
def reflect_conj_concat(kern, dim):
    """Build a full Hermitian-symmetric kernel from one half of it.

    The half kernel is multiplied by ``[1, -1]`` along dimension 2 (presumably
    the real/imaginary axis, i.e. a complex conjugate — TODO confirm layout).
    The trailing ``dim + 1`` dimensions are then index-reversed with
    ``k -> (-k) mod n``. The first slice along the concatenation axis is
    replaced by zeros, and the result is appended to ``kern``.

    Args:
        kern (tensor): One half of a full, Hermitian-symmetric kernel.
        dim (int): Concatenation axis, counted from the end (0 = last axis);
            this axis and all axes after it are reflected.

    Returns:
        tensor: The full FFT kernel after Hermitian-symmetric reflection.
    """
    dtype, device = kern.dtype, kern.device
    axis = -1 - dim
    reflect_axes = [axis + offset for offset in range(abs(axis))]

    # zero slab that replaces the first slice of the reflected half
    zero_shape = list(kern.shape)
    zero_shape[axis] = 1
    zero_slab = torch.zeros(*zero_shape, dtype=dtype, device=device)

    # [1, -1] placed on dimension 2, broadcastable against kern
    sign = torch.tensor([1, -1], dtype=dtype, device=device)[None, None]
    while sign.ndim < kern.ndim:
        sign = sign[..., None]

    mirrored = sign * kern
    for ax in reflect_axes:
        n = mirrored.shape[ax]
        flip_index = torch.remainder(-torch.arange(n, device=device), n)
        mirrored = mirrored.index_select(ax, flip_index)

    tail = mirrored.narrow(axis, 1, mirrored.shape[axis] - 1)
    mirrored = torch.cat((zero_slab, tail), axis)

    return torch.cat((kern, mirrored), axis)
示例16
def hermitify(kern, dim):
    """Enforce Hermitian symmetry.

    Takes an approximately Hermitian-symmetric kernel and makes it exactly
    Hermitian-symmetric. It builds a copy whose coordinates in the trailing
    ``dim + 1`` dimensions are reversed (``k -> (-k) mod n``) and which is
    multiplied by ``[1, -1]`` along dimension 2 (presumably the conjugate of
    a real/imag layout — TODO confirm). It then averages that copy with the
    original.

    Args:
        kern (tensor): An approximately Hermitian-symmetric kernel.
        dim (int): The last imaging dimension.

    Returns:
        tensor: A Hermitian-symmetric kernel.
    """
    dtype, device = kern.dtype, kern.device
    first_axis = kern.ndim - 1 - dim

    original = kern.clone()

    # reverse coordinates along every reflected axis
    mirrored = kern
    for ax in range(first_axis, kern.ndim):
        n = mirrored.shape[ax]
        mirrored = mirrored.index_select(
            ax, torch.remainder(-torch.arange(n, device=device), n)
        )

    # conjugate: [1, -1] on dimension 2, broadcast over the rest
    sign = torch.tensor([1, -1], dtype=dtype, device=device).reshape(1, 1, 2)
    while sign.ndim < mirrored.ndim:
        sign = sign.unsqueeze(-1)
    mirrored = sign * mirrored

    return (original + mirrored) / 2
示例17
def mod(self, divisor):
        """Return ``self.curr_buffer`` modulo ``divisor`` as an int32 tensor.

        ``divisor`` is cast to long first. The result follows
        ``torch.remainder`` semantics, where the sign follows the divisor.
        """
        long_divisor = divisor.long()
        remainders = torch.remainder(self.curr_buffer, long_divisor)
        return remainders.int()
示例18
def forward(ctx, h, z, buf, mask, slice_dim=0):
        """Divide ``h`` by ``z`` and multiply by ``2**forget_radix``, using ``buf`` for exactness.

        For masked entries:
        1. ``buf`` is multiplied by ``z`` and receives ``h mod z`` (columns ``slice_dim:``).
        2. ``h`` is divided by ``z``.
        3. ``h`` is multiplied by ``2**forget_radix`` and the low bits popped by
           ``buf.mod_divide`` are OR-ed into columns ``slice_dim:``.

        Args:
            ctx: autograd context; not used in this body.
            h: tensor to update; ``mask`` and ``z`` are indexed with ``h < 0``,
                so they must have the same shape as ``h``.
            z: divisor tensor.
            buf: buffer object providing ``mul``, ``add`` and ``mod_divide``.
            mask: combined as ``mask*new + (1-mask)*old``; presumably 0/1 — TODO confirm.
            slice_dim: only columns ``slice_dim:`` of dim 1 interact with ``buf``.

        Returns:
            The updated ``h``.

        NOTE(review): ``h[h<0] = ...`` writes into the caller's ``h`` in place.
        NOTE(review): ``forget_radix`` is a module-level name not visible here.
        """
        buf.mul(z[:,slice_dim:], mask[:, slice_dim:])
        h_mod = torch.remainder(h[:,slice_dim:], z[:,slice_dim:])
        buf.add(h_mod, mask[:, slice_dim:])
        # Subtract (z - 1) from negative masked entries; presumably so that a
        # truncating integer `/` rounds toward -inf, matching torch.remainder above.
        h[h<0] = mask[h<0]*(h[h<0]-(z[h<0]-1)) + (1-mask[h<0])*h[h<0]
        # NOTE(review): on PyTorch versions where `/` is true division for integer
        # tensors, this produces floats and the `__or__` below would fail —
        # confirm the targeted PyTorch version.
        h = mask*(h / z) + (1-mask)*h 

        h = mask*(h * (2**forget_radix)) + (1-mask)*h
        buf_mod = buf.mod_divide(forget_radix, mask[:, slice_dim:])
        h[:,slice_dim:] = mask[:, slice_dim:]*(h[:,slice_dim:].__or__(buf_mod)) +\
            (1-mask[:, slice_dim:])*h[:,slice_dim:]

        return h 
示例19
def __init__(self, num_categories, embedding_dim, num_collisions,
                 operation='mult', max_norm=None, norm_type=2.,
                 scale_grad_by_freq=False, mode='mean', sparse=False,
                 _weight=None):
        """Create a quotient-remainder embedding bag.

        Args:
            num_categories: number of distinct category indices.
            embedding_dim: an int, a one-element sequence (same width for both
                tables), or a two-element sequence ``[quotient_dim, remainder_dim]``.
            num_collisions: divisor that splits an index into quotient and
                remainder; also the row count of the remainder table.
            operation: ``'concat'``, ``'mult'`` or ``'add'``. ``'mult'`` and
                ``'add'`` require both widths to be equal.
            max_norm, norm_type, scale_grad_by_freq, mode, sparse: stored on
                the instance.
            _weight: optional pair of initial tables ``(weight_q, weight_r)``
                with shapes ``num_embeddings[i] x embedding_dim[i]``; when
                omitted, tables are allocated and ``reset_parameters`` is called.
        """
        super(QREmbeddingBag, self).__init__()

        assert operation in ['concat', 'mult', 'add'], 'Not valid operation!'

        self.num_categories = num_categories
        if isinstance(embedding_dim, int):
            self.embedding_dim = [embedding_dim, embedding_dim]
        elif len(embedding_dim) == 1:
            # Unpack the single width. The old code stored the sequence itself,
            # giving [[d], [d]], which broke the table allocation below.
            self.embedding_dim = [embedding_dim[0], embedding_dim[0]]
        else:
            self.embedding_dim = embedding_dim
        self.num_collisions = num_collisions
        self.operation = operation
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq

        if self.operation == 'add' or self.operation == 'mult':
            assert self.embedding_dim[0] == self.embedding_dim[1], \
                'Embedding dimensions do not match!'

        # quotient table: ceil(num_categories / num_collisions) rows; remainder table: num_collisions rows
        self.num_embeddings = [int(np.ceil(num_categories / num_collisions)),
            num_collisions]

        if _weight is None:
            self.weight_q = Parameter(torch.Tensor(self.num_embeddings[0], self.embedding_dim[0]))
            self.weight_r = Parameter(torch.Tensor(self.num_embeddings[1], self.embedding_dim[1]))
            self.reset_parameters()
        else:
            assert list(_weight[0].shape) == [self.num_embeddings[0], self.embedding_dim[0]], \
                'Shape of weight for quotient table does not match num_embeddings and embedding_dim'
            assert list(_weight[1].shape) == [self.num_embeddings[1], self.embedding_dim[1]], \
                'Shape of weight for remainder table does not match num_embeddings and embedding_dim'
            self.weight_q = Parameter(_weight[0])
            self.weight_r = Parameter(_weight[1])
        self.mode = mode
        self.sparse = sparse
示例20
def __init__(self, num_categories, embedding_dim, num_collisions,
                 operation='mult', max_norm=None, norm_type=2.,
                 scale_grad_by_freq=False, mode='mean', sparse=False,
                 _weight=None):
        """Create a quotient-remainder embedding bag.

        Args:
            num_categories: number of distinct category indices.
            embedding_dim: an int, a one-element sequence (same width for both
                tables), or a two-element sequence ``[quotient_dim, remainder_dim]``.
            num_collisions: divisor that splits an index into quotient and
                remainder; also the row count of the remainder table.
            operation: ``'concat'``, ``'mult'`` or ``'add'``. ``'mult'`` and
                ``'add'`` require both widths to be equal.
            max_norm, norm_type, scale_grad_by_freq, mode, sparse: stored on
                the instance.
            _weight: optional pair of initial tables ``(weight_q, weight_r)``
                with shapes ``num_embeddings[i] x embedding_dim[i]``; when
                omitted, tables are allocated and ``reset_parameters`` is called.
        """
        super(QREmbeddingBag, self).__init__()

        assert operation in ['concat', 'mult', 'add'], 'Not valid operation!'

        self.num_categories = num_categories
        if isinstance(embedding_dim, int):
            self.embedding_dim = [embedding_dim, embedding_dim]
        elif len(embedding_dim) == 1:
            # Unpack the single width. The old code stored the sequence itself,
            # giving [[d], [d]], which broke the table allocation below.
            self.embedding_dim = [embedding_dim[0], embedding_dim[0]]
        else:
            self.embedding_dim = embedding_dim
        self.num_collisions = num_collisions
        self.operation = operation
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq

        if self.operation == 'add' or self.operation == 'mult':
            assert self.embedding_dim[0] == self.embedding_dim[1], \
                'Embedding dimensions do not match!'

        # quotient table: ceil(num_categories / num_collisions) rows; remainder table: num_collisions rows
        self.num_embeddings = [int(np.ceil(num_categories / num_collisions)),
            num_collisions]

        if _weight is None:
            self.weight_q = Parameter(torch.Tensor(self.num_embeddings[0], self.embedding_dim[0]))
            self.weight_r = Parameter(torch.Tensor(self.num_embeddings[1], self.embedding_dim[1]))
            self.reset_parameters()
        else:
            assert list(_weight[0].shape) == [self.num_embeddings[0], self.embedding_dim[0]], \
                'Shape of weight for quotient table does not match num_embeddings and embedding_dim'
            assert list(_weight[1].shape) == [self.num_embeddings[1], self.embedding_dim[1]], \
                'Shape of weight for remainder table does not match num_embeddings and embedding_dim'
            self.weight_q = Parameter(_weight[0])
            self.weight_r = Parameter(_weight[1])
        self.mode = mode
        self.sparse = sparse
示例21
def log_uniform_sample(N, size):
    """Draw integer samples in ``[0, N)`` from a log-uniform distribution.

    Each sample is ``floor(exp(u * log N)) - 1`` with ``u ~ U[0, 1)``, reduced
    modulo ``N``. Small indices are therefore much more likely than large
    ones.

    NOTE(review): since ``exp(u * log N) < N``, the value ``N - 1`` only
    appears through floating-point rounding. Confirm this is intended.

    Args:
        N: number of classes (exclusive upper bound).
        size: output shape, given as an int, a tuple or a ``torch.Size``.

    Returns:
        Long tensor of shape ``size`` with values in ``[0, N)``.
    """
    log_N = math.log(N)
    # torch.empty treats a tuple as a shape. torch.Tensor(tuple) instead builds
    # a tensor *containing* the tuple's values, giving the wrong shape.
    x = torch.empty(size).uniform_(0, 1)
    value = torch.exp(x * log_N).long() - 1
    return torch.remainder(value, N)
示例22
def log_uniform_sample(N, size):
    """Draw integer samples in ``[0, N)`` from a log-uniform distribution.

    Each sample is ``floor(exp(u * log N)) - 1`` with ``u ~ U[0, 1)``, reduced
    modulo ``N``. Small indices are therefore much more likely than large
    ones.

    NOTE(review): since ``exp(u * log N) < N``, the value ``N - 1`` only
    appears through floating-point rounding. Confirm this is intended.

    Args:
        N: number of classes (exclusive upper bound).
        size: output shape, given as an int, a tuple or a ``torch.Size``.

    Returns:
        Long tensor of shape ``size`` with values in ``[0, N)``.
    """
    log_N = math.log(N)
    # torch.empty treats a tuple as a shape. torch.Tensor(tuple) instead builds
    # a tensor *containing* the tuple's values, giving the wrong shape.
    x = torch.empty(size).uniform_(0, 1)
    value = torch.exp(x * log_N).long() - 1
    return torch.remainder(value, N)
示例23
def fmod(t1, t2):
    """
    Element-wise remainder of dividing t1 by t2, computed like the C library function ``fmod``. Not commutative.
    Both operands may be scalars or tensors, and both may contain floating-point values.

    Parameters
    ----------
    t1: tensor or scalar
        The dividend (may contain floats).
    t2: tensor or scalar
        The divisor (may contain floats).

    Returns
    -------
    result: ht.DNDarray
        A tensor containing the element-wise remainder of t1 divided by t2.
        It has the same sign as the dividend t1.

    Examples
    --------
    >>> import heat as ht
    >>> ht.fmod(2.0, 2.0)
    tensor([0.])

    >>> T1 = ht.float32([[1, 2], [3, 4]])
    >>> T2 = ht.float32([[2, 2], [2, 2]])
    >>> ht.fmod(T1, T2)
    tensor([[1., 0.],
            [1., 0.]])

    >>> s = 2.0
    >>> ht.fmod(s, T1)
    tensor([[0., 0.],
            [2., 2.]])
    """
    elementwise_op = torch.fmod
    return operations.__binary_op(elementwise_op, t1, t2)
示例24
def mod(t1, t2):
    """
    Element-wise remainder of dividing t1 by t2 (i.e. ``t1 % t2``). Not commutative.
    Both operands may be scalars or tensors.

    This is an alias: t1 and t2 are passed unchanged to ``remainder``.

    Parameters
    ----------
    t1: tensor or scalar
        The dividend.
    t2: tensor or scalar
        The divisor.

    Returns
    -------
    result: ht.DNDarray
        A tensor containing the element-wise remainder of t1 divided by t2.
        It has the same sign as the divisor t2.

    Examples
    --------
    >>> import heat as ht
    >>> ht.mod(2, 2)
    tensor([0])

    >>> T1 = ht.int32([[1, 2], [3, 4]])
    >>> T2 = ht.int32([[2, 2], [2, 2]])
    >>> ht.mod(T1, T2)
    tensor([[1, 0],
            [1, 0]], dtype=torch.int32)

    >>> s = 2
    >>> ht.mod(s, T1)
    tensor([[0, 0],
            [2, 2]], dtype=torch.int32)
    """
    result = remainder(t1, t2)
    return result
示例25
def remainder(t1, t2):
    """
    Element-wise remainder of dividing t1 by t2 (i.e. ``t1 % t2``). Not commutative.
    Both operands may be scalars or tensors.

    Parameters
    ----------
    t1: tensor or scalar
        The dividend.
    t2: tensor or scalar
        The divisor.

    Returns
    -------
    result: ht.DNDarray
        A tensor containing the element-wise remainder of t1 divided by t2.
        It has the same sign as the divisor t2.

    Examples
    --------
    >>> import heat as ht
    >>> ht.remainder(2, 2)
    tensor([0])

    >>> T1 = ht.int32([[1, 2], [3, 4]])
    >>> T2 = ht.int32([[2, 2], [2, 2]])
    >>> ht.remainder(T1, T2)
    tensor([[1, 0],
            [1, 0]], dtype=torch.int32)

    >>> s = 2
    >>> ht.remainder(s, T1)
    tensor([[0, 0],
            [2, 2]], dtype=torch.int32)
    """
    elementwise_op = torch.remainder
    return operations.__binary_op(elementwise_op, t1, t2)
示例26
def log_uniform_sample(N, size):
    """Draw integer samples in ``[0, N)`` from a log-uniform distribution.

    Each sample is ``floor(exp(u * log N)) - 1`` with ``u ~ U[0, 1)``, reduced
    modulo ``N``. Small indices are therefore much more likely than large
    ones.

    NOTE(review): since ``exp(u * log N) < N``, the value ``N - 1`` only
    appears through floating-point rounding. Confirm this is intended.

    Args:
        N: number of classes (exclusive upper bound).
        size: output shape, given as an int, a tuple or a ``torch.Size``.

    Returns:
        Long tensor of shape ``size`` with values in ``[0, N)``.
    """
    log_N = math.log(N)
    # torch.empty treats a tuple as a shape. torch.Tensor(tuple) instead builds
    # a tensor *containing* the tuple's values, giving the wrong shape.
    x = torch.empty(size).uniform_(0, 1)
    value = torch.exp(x * log_N).long() - 1
    return torch.remainder(value, N)
示例27
def log_uniform_sample(N, size):
    """Draw integer samples in ``[0, N)`` from a log-uniform distribution.

    Each sample is ``floor(exp(u * log N)) - 1`` with ``u ~ U[0, 1)``, reduced
    modulo ``N``. Small indices are therefore much more likely than large
    ones.

    NOTE(review): since ``exp(u * log N) < N``, the value ``N - 1`` only
    appears through floating-point rounding. Confirm this is intended.

    Args:
        N: number of classes (exclusive upper bound).
        size: output shape, given as an int, a tuple or a ``torch.Size``.

    Returns:
        Long tensor of shape ``size`` with values in ``[0, N)``.
    """
    log_N = math.log(N)
    # torch.empty treats a tuple as a shape. torch.Tensor(tuple) instead builds
    # a tensor *containing* the tuple's values, giving the wrong shape.
    x = torch.empty(size).uniform_(0, 1)
    value = torch.exp(x * log_N).long() - 1
    return torch.remainder(value, N)
示例28
def log_uniform_sample(N, size):
    """Draw integer samples in ``[0, N)`` from a log-uniform distribution.

    Each sample is ``floor(exp(u * log N)) - 1`` with ``u ~ U[0, 1)``, reduced
    modulo ``N``. Small indices are therefore much more likely than large
    ones.

    NOTE(review): since ``exp(u * log N) < N``, the value ``N - 1`` only
    appears through floating-point rounding. Confirm this is intended.

    Args:
        N: number of classes (exclusive upper bound).
        size: output shape, given as an int, a tuple or a ``torch.Size``.

    Returns:
        Long tensor of shape ``size`` with values in ``[0, N)``.
    """
    log_N = math.log(N)
    # torch.empty treats a tuple as a shape. torch.Tensor(tuple) instead builds
    # a tensor *containing* the tuple's values, giving the wrong shape.
    x = torch.empty(size).uniform_(0, 1)
    value = torch.exp(x * log_N).long() - 1
    return torch.remainder(value, N)
示例29
def calc_coef_and_indices(tm, kofflist, Jval, table, centers, L, dims, conjcoef=False):
    """Calculate interpolation coefficients and flattened on-grid indices.

    Args:
        tm (tensor): normalized frequency locations.
        kofflist (tensor): A tensor with offset locations to first elements in
            list of nearest neighbors.
        Jval (tensor): A tuple-like tensor for how much to increment offsets.
        table (list): A list of tensors tabulating a Kaiser-Bessel
            interpolation kernel.
        centers (tensor): A tensor with the center locations of the table for
            each dimension.
        L (tensor): A tensor with the table size in each dimension.
        dims (tensor): A tensor with image dimensions.
        conjcoef (boolean, default=False): A boolean for whether to compute
            normal or complex conjugate interpolation coefficients
            (conjugate needed for adjoint).

    Returns:
        tuple: ``(coef, arr_ind)``. ``coef`` starts as the stacked pair
        ``(ones, zeros)`` (presumably real/imaginary parts along dim 0) and is
        multiplied by one table lookup per spatial dimension. ``arr_ind`` is
        the row-major flat index ``sum_d (grid_d mod dims[d]) * prod(dims[d+1:])``.
    """
    float_dtype = tm.dtype
    device = tm.device
    index_dtype = torch.long

    ndims, M = tm.shape[0], tm.shape[1]

    # neighbor grid positions, and their rounded offsets into the kernel table
    grid_pos = (kofflist + Jval.unsqueeze(1)).to(float_dtype)
    table_offset = torch.round(
        (tm - grid_pos) * L.unsqueeze(1)).to(dtype=index_dtype)
    grid_pos = grid_pos.to(index_dtype)

    arr_ind = torch.zeros((M,), dtype=index_dtype, device=device)
    coef = torch.stack((
        torch.ones(M, dtype=float_dtype, device=device),
        torch.zeros(M, dtype=float_dtype, device=device),
    ))

    multiply = conj_complex_mult if conjcoef else complex_mult
    for d in range(ndims):  # spatial dimension
        kernel_vals = table[d][:, table_offset[d, :] + centers[d]]
        coef = multiply(coef, kernel_vals, dim=0)
        wrapped = torch.remainder(grid_pos[d, :], dims[d]).view(-1)
        arr_ind = arr_ind + wrapped * torch.prod(dims[d + 1:])

    return coef, arr_ind
示例30
def _val_epoch(self, epoch):
        """Run one validation pass, optionally also under adversarial attack.

        Clean loss and accuracy are always tracked. When ``self.attack`` is
        set, each batch is also attacked toward a random wrong class twice:
        once with ``scale_eps=self.scale_eps`` and once with
        ``scale_eps=False``. The model's loss and accuracy on those inputs are
        measured against the true labels. Rank 0 logs the averages and, when
        ``self.verbose`` is set, prints them.

        Args:
            epoch: epoch number passed to ``self.logger.log``.
        """
        self.model.eval()

        val_std_loss = Metric('val_std_loss')
        val_std_acc = Metric('val_std_acc')

        val_adv_acc = Metric('val_adv_acc')
        val_adv_loss = Metric('val_adv_loss')
        val_max_adv_acc = Metric('val_max_adv_acc')
        val_max_adv_loss = Metric('val_max_adv_loss')

        for batch_idx, (data, target) in enumerate(self.val_loader):
            if self.cuda:
                data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
            with torch.no_grad():
                output = self.model(data)
                val_std_loss.update(F.cross_entropy(output, target))
                val_std_acc.update(accuracy(output, target))
            if self.attack:
                num_classes = len(self.val_dataset.classes)
                # randint's upper bound is exclusive, so offset + 1 is in
                # [1, num_classes - 1] and rand_target never equals target.
                # Sample on target's device instead of hard-coding 'cuda'.
                offset = torch.randint(
                    0, num_classes - 1, target.size(),
                    dtype=target.dtype, device=target.device)
                rand_target = torch.remainder(target + offset + 1, num_classes)
                data_adv = self.attack(self.model, data, rand_target,
                                       avoid_target=False, scale_eps=self.scale_eps)
                data_max_adv = self.attack(self.model, data, rand_target, avoid_target=False, scale_eps=False)
                with torch.no_grad():
                    output_adv = self.model(data_adv)
                    val_adv_loss.update(F.cross_entropy(output_adv, target))
                    val_adv_acc.update(accuracy(output_adv, target))

                    output_max_adv = self.model(data_max_adv)
                    val_max_adv_loss.update(F.cross_entropy(output_max_adv, target))
                    val_max_adv_acc.update(accuracy(output_max_adv, target))
            # Put the model back into eval mode; presumably the attack may change it.
            self.model.eval()

        if hvd.rank() == 0:
            # 'val_max_adv_loss' was previously logged under the duplicate key
            # 'val_adv_loss', which overwrote the real adversarial loss.
            log_dict = {'val_std_loss': val_std_loss.avg.item(),
                        'val_std_acc': val_std_acc.avg.item(),
                        'val_adv_loss': val_adv_loss.avg.item(),
                        'val_adv_acc': val_adv_acc.avg.item(),
                        'val_max_adv_loss': val_max_adv_loss.avg.item(),
                        'val_max_adv_acc': val_max_adv_acc.avg.item()}
            self.logger.log(log_dict, epoch)

            # log_dict only exists on rank 0; printing it on other ranks raised NameError.
            if self.verbose:
                print(log_dict)

        self.optimizer.synchronize()
        self.optimizer.zero_grad()