Python source code examples: torch.log_softmax()
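
A quick standalone sketch before the project examples (illustrative only, not taken from any of the projects below): torch.log_softmax(input, dim) returns the logarithm of the softmax along dim, computed in one numerically stable step instead of the two-step log(softmax(x)).

import torch

# Illustrative input; any float tensor works.
logits = torch.tensor([[1.0, 2.0, 3.0]])

# Single fused, numerically stable call.
log_probs = torch.log_softmax(logits, dim=1)

# Equivalent (but less stable for large logits) two-step computation.
naive = torch.log(torch.softmax(logits, dim=1))

print(torch.allclose(log_probs, naive))  # True
print(log_probs.exp().sum(dim=1))        # tensor([1.]) -- rows sum to 1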

Example 1
def forward(self, x):
        # First compute the linear part
        linear_part = self.linear(x)

        # Compute the pairwise interaction (cross) part
        interaction_part = 0.0
        for i in range(self.fea_num):
            for j in range(i + 1, self.fea_num):
                v_ifj = self.v[i, self.field_map_dict[j], :, :]
                v_jfi = self.v[j, self.field_map_dict[i], :, :]

                xij = torch.unsqueeze(x[:, i] * x[:, j], dim=1)
                v_ijji = torch.unsqueeze(torch.sum(v_ifj * v_jfi, dim=0), dim=0)

                interaction_part += torch.mm(xij, v_ijji)

        output = linear_part + interaction_part
        output = torch.log_softmax(output, dim=1)
        return output 
Example 2
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args):
    epoch_loss = 0.0
    for image, target, input_len, target_len in tqdm(data_loader):
        image = image.to(device)
        # print(target, target_len, input_len)
        outputs = model(image.to(torch.float32))  # [B,N,C]
        outputs = torch.log_softmax(outputs, dim=2)
        outputs = outputs.permute([1, 0, 2])  # [N,B,C]
        loss = criterion(outputs[:], target, input_len, target_len)
        # Gradient update
        model.zero_grad()
        loss.backward()
        optimizer.step()
        # Loss for the current epoch
        epoch_loss += loss.item() * image.size(0)
        if np.isnan(loss.item()):
            print(target, input_len, target_len)

    epoch_loss = epoch_loss / len(data_loader.dataset)
    # Print log, save weights
    print('Epoch: {}/{} loss: {:03f}'.format(epoch + 1, args.epochs, epoch_loss))
    return epoch_loss 
Example 3
def forward(self, task_id, x, y, seq_len):
        words_emb = self.embedding(x)
        char_emb = self.char(x)
        x = torch.cat([words_emb, char_emb], dim=-1)
        x, _ = self.lstm(x, seq_len)
        x = self.dropout(x)
        logit = self.out[task_id[0]](x)

        seq_mask = seq_len_to_mask(seq_len, x.size(1))
        if self.crf is not None:
            logit = torch.log_softmax(logit, dim=-1)
            loss = self.crf[task_id[0]](logit, y, seq_mask).mean()
            pred = self.crf[task_id[0]].viterbi_decode(logit, seq_mask)[0]
        else:
            loss = ce_loss(logit, y, seq_mask)
            pred = torch.argmax(logit, dim=2)
        return {"loss": loss, "pred": pred} 
Example 4
def distillation(logits_student, logits_teacher, ylens, temperature=5.0):
    """Compute cross entropy loss for knowledge distillation of sequence-to-sequence models.

    Args:
        logits_student (FloatTensor): `[B, T, vocab]`
        logits_teacher (FloatTensor): `[B, T, vocab]`
        ylens (IntTensor): `[B]`
        temperature (float):
    Returns:
        loss_mean (FloatTensor): `[1]`

    """
    bs, _, vocab = logits_student.size()

    log_probs_student = torch.log_softmax(logits_student, dim=-1)
    probs_teacher = torch.softmax(logits_teacher / temperature, dim=-1).data
    loss = -torch.mul(probs_teacher, log_probs_student)
    loss_mean = np.sum([loss[b, :ylens[b], :].sum() for b in range(bs)]) / ylens.sum()
    return loss_mean 
Example 5
def kldiv_lsm_ctc(logits, ylens):
    """Compute KL divergence loss for label smoothing of CTC and Transducer models.

    Args:
        logits (FloatTensor): `[B, T, vocab]`
        ylens (IntTensor): `[B]`
    Returns:
        loss_mean (FloatTensor): `[1]`

    """
    bs, _, vocab = logits.size()

    log_uniform = logits.new_zeros(logits.size()).fill_(math.log(1 / (vocab - 1)))
    probs = torch.softmax(logits, dim=-1)
    log_probs = torch.log_softmax(logits, dim=-1)
    loss = torch.mul(probs, log_probs - log_uniform)
    loss_mean = np.sum([loss[b, :ylens[b], :].sum() for b in range(bs)]) / ylens.sum()
    # assert loss_mean >= 0
    return loss_mean 
Example 6
def focal_loss(logits, ys, ylens, alpha, gamma):
    """Compute focal loss.

    Args:
        logits (FloatTensor): `[B, T, vocab]`
        ys (LongTensor): Indices of labels. `[B, L]`
        ylens (IntTensor): `[B]`
        alpha (float):
        gamma (float):
    Returns:
        loss_mean (FloatTensor): `[1]`

    """
    bs = ys.size(0)

    log_probs = torch.log_softmax(logits, dim=-1)
    probs_inv = -torch.softmax(logits, dim=-1) + 1
    loss = -alpha * torch.mul(torch.pow(probs_inv, gamma), log_probs)
    loss_mean = np.sum([loss[b, :ylens[b], :].sum() for b in range(bs)]) / ylens.sum()
    return loss_mean 
Example 7
def greedy(self, eouts, elens):
        """Greedy decoding.

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (np.ndarray): `[B]`
        Returns:
            hyps (np.ndarray): Best path hypothesis. `[B, L]`

        """
        log_probs = torch.log_softmax(self.output(eouts), dim=-1)
        best_paths = log_probs.argmax(-1)  # `[B, L]`

        hyps = []
        for b in range(eouts.size(0)):
            indices = [best_paths[b, t].item() for t in range(elens[b])]

            # Step 1. Collapse repeated labels
            collapsed_indices = [x[0] for x in groupby(indices)]

            # Step 2. Remove all blank labels
            best_hyp = [x for x in filter(lambda x: x != self.blank, collapsed_indices)]
            hyps.append(np.array(best_hyp))

        return np.array(hyps) 
Example 8
def test_log_softmax():
    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
    src.requires_grad_()
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_log_softmax(src, index)

    out0 = torch.log_softmax(torch.tensor([0.2, 0.2]), dim=-1)
    out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
    out2 = torch.log_softmax(torch.tensor([7], dtype=torch.float), dim=-1)
    out4 = torch.log_softmax(torch.tensor([-1, float('-inf')]), dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    assert torch.allclose(out, expected)

    out.backward(torch.randn_like(out)) 
Example 9
def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            true_dist = x.clone()
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (B,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 
Example 10
def forward(self, x, y, get_scores=False):
        """
        Compute the loss, and optionally the scores.
        """
        assert (y == self.pad_index).sum().item() == 0

        if self.asm is False:
            scores = self.proj(x).view(-1, self.n_words)
            if self.label_smoothing == 0.0:
                loss = F.cross_entropy(scores, y, reduction='elementwise_mean')
            else:
                lprobs = torch.log_softmax(scores, dim=1)
                nll_loss = -lprobs.gather(dim=-1, index=y.unsqueeze(1))
                smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
                nll_loss, smooth_loss = nll_loss.sum(), smooth_loss.sum()
                eps_i = self.label_smoothing / lprobs.size(-1)
                loss = (1. - self.label_smoothing) * nll_loss + eps_i * smooth_loss
                loss = loss / x.shape[0]
        else:
            _, loss = self.proj(x, y)
            scores = self.proj.log_prob(x) if get_scores else None

        return scores, loss 
Example 11
def init_step(self, beam, expected_len_pen):
        # init_preds: [4, 3, 5, 6, 7] - no EOS's
        init_scores = torch.log_softmax(torch.tensor(
            [[0, 0, 0, 4, 5, 3, 2, 1]], dtype=torch.float), dim=1)
        init_scores = deepcopy(init_scores.repeat(
            self.BATCH_SZ * self.BEAM_SZ, 1))
        new_scores = init_scores + beam.topk_log_probs.view(-1).unsqueeze(1)
        expected_beam_scores, expected_preds_0 = new_scores \
            .view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS) \
            .topk(self.BEAM_SZ, dim=-1)
        beam.advance(deepcopy(init_scores), self.random_attn())
        self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
        self.assertTrue(beam.topk_ids.equal(expected_preds_0))
        self.assertFalse(beam.is_finished.any())
        self.assertFalse(beam.done)
        return expected_beam_scores 
Example 12
def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target: target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            true_dist = x.clone()
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (B,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 
Example 13
def forward(ctx, logits, label, lb_smooth, lb_ignore):
        # prepare label
        num_classes = logits.size(1)
        lb_pos, lb_neg = 1. - lb_smooth, lb_smooth / num_classes
        label = label.clone().detach()
        ignore = label == lb_ignore
        n_valid = (label != lb_ignore).sum()
        label[ignore] = 0
        lb_one_hot = torch.empty_like(logits).fill_(
            lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()

        ignore = ignore.nonzero()
        _, M = ignore.size()
        a, *b = ignore.chunk(M, dim=1)
        mask = [a, torch.arange(logits.size(1)), *b]
        lb_one_hot[mask] = 0
        coeff = (num_classes - 1) * lb_neg + lb_pos

        ctx.variables = coeff, mask, logits, lb_one_hot

        loss = torch.log_softmax(logits, dim=1).neg_().mul_(lb_one_hot).sum(dim=1)
        return loss 
Example 14
def forward(self, x, target):
        """Compute loss between x and target

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target: target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.reshape(-1)
        with torch.no_grad():
            true_dist = x.clone()
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (B,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 
Example 15
def discriminate(self, z, edge_index):
        """Given node embeddings :obj:`z`, classifies the link relation
        between node pairs :obj:`edge_index` to be either positive,
        negative or non-existent.

        Args:
            z (Tensor): The node embeddings.
            edge_index (LongTensor): The edge indices.
        """
        value = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=1)
        value = self.lin(value)
        return torch.log_softmax(value, dim=1) 
Example 16
def forward(self, x, edge_index):
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x_all = x.view(-1, 1, self.hidden_channels)
        for conv in self.convs:
            x = F.relu(conv(x_all, edge_index))
            x = x.view(-1, 1, self.hidden_channels)
            x_all = torch.cat([x_all, x], dim=1)
        x = x_all[:, -1]
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return torch.log_softmax(x, dim=1) 
Example 17
def forward(self, features):
        features = self.fc(features)
        return torch.log_softmax(features, dim=-1) 
Example 18
def forward(self, x):
        return th.log_softmax(
            self.proj(x), dim=-1
        ) 
Example 19
def train_step(protonet, datax, datay, Ns, Nc, Nq):
    optimizer.zero_grad()
    Qx, Qy = protonet(datax, datay, Ns, Nc, Nq, np.unique(datay))
    pred = torch.log_softmax(Qx, dim=-1)
    loss = F.nll_loss(pred, Qy)
    loss.backward()
    optimizer.step()
    acc = torch.mean((torch.argmax(pred, 1) == Qy).float())
    return loss, acc 
Example 20
def test_step(protonet, datax, datay, Ns, Nc, Nq):
    Qx, Qy = protonet(datax, datay, Ns, Nc, Nq, np.unique(datay))
    pred = torch.log_softmax(Qx, dim=-1)
    loss = F.nll_loss(pred, Qy)
    acc = torch.mean((torch.argmax(pred, 1) == Qy).float())
    return loss, acc 
Example 21
def cross_entropy_lsm(logits, ys, lsm_prob, ignore_index, training, normalize_length=False):
    """Compute cross entropy loss for label smoothing of sequence-to-sequence models.

    Args:
        logits (FloatTensor): `[B, T, vocab]`
        ys (LongTensor): Indices of labels. `[B, L]`
        lsm_prob (float): label smoothing probability
        ignore_index (int): index for padding
        normalize_length (bool): normalize XE loss by target sequence length
    Returns:
        loss_mean (FloatTensor): `[1]`
        ppl (float): perplexity

    """
    bs, _, vocab = logits.size()
    ys = ys.view(-1)
    logits = logits.view((-1, logits.size(2)))

    if lsm_prob == 0 or not training:
        loss = F.cross_entropy(logits, ys,
                               ignore_index=ignore_index, reduction='mean')
        ppl = np.exp(loss.item())
        if not normalize_length:
            loss *= (ys != ignore_index).sum() / bs
    else:
        with torch.no_grad():
            target_dist = logits.new_zeros(logits.size())
            target_dist.fill_(lsm_prob / (vocab - 1))
            mask = (ys == ignore_index)
            ys_masked = ys.masked_fill(mask, 0)
            target_dist.scatter_(1, ys_masked.unsqueeze(1), 1 - lsm_prob)

        log_probs = torch.log_softmax(logits, dim=-1)
        loss_sum = -torch.mul(target_dist, log_probs)
        n_tokens = len(ys) - mask.sum().item()
        denom = n_tokens if normalize_length else bs
        loss = loss_sum.masked_fill(mask.unsqueeze(1), 0).sum() / denom

        ppl = np.exp(loss.item()) if normalize_length else np.exp(loss.item() * bs / n_tokens)

    return loss, ppl 
Example 22
def predict(self, ys, state=None, mems=None, cache=None):
        """Precict function for ASR.

        Args:
            ys (LongTensor): `[B, L]`
            state:
                - RNNLM: dict
                    hxs (FloatTensor): `[n_layers, B, n_units]`
                    cxs (FloatTensor): `[n_layers, B, n_units]`
                - TransformerLM (LongTensor): `[B, L]`
                - TransformerXL (list): length `n_layers + 1`, each of which contains a tensor `[B, L, d_model]`
            mems (list):
            cache (list):
        Returns:
            lmout (FloatTensor): `[B, L, vocab]`, used for LM integration such as cold fusion
            state:
                - RNNLM: dict
                    hxs (FloatTensor): `[n_layers, B, n_units]`
                    cxs (FloatTensor): `[n_layers, B, n_units]`
                - TransformerLM (LongTensor): `[B, L]`
                - TransformerXL (list): length `n_layers + 1`, each of which contains a tensor `[B, L, d_model]`
            log_probs (FloatTensor): `[B, L, vocab]`

        """
        logits, lmout, new_state = self.decode(ys, state, mems=mems, cache=cache,
                                               incremental=True)
        log_probs = torch.log_softmax(logits, dim=-1)
        return lmout, new_state, log_probs 
Example 23
def recognize(self, h, recog_args):
        """Greedy search implementation for transformer-transducer.

        Args:
            h (torch.Tensor): encoder hidden state sequences (maxlen_in, Henc)
            recog_args (Namespace): argument Namespace containing options

        Returns:
            hyp (list of dicts): 1-best decoding results

        """
        hyp = {"score": 0.0, "yseq": [self.blank]}

        ys = to_device(self, torch.tensor(hyp["yseq"], dtype=torch.long)).unsqueeze(0)
        ys_mask = to_device(self, subsequent_mask(1).unsqueeze(0))
        y, c = self.forward_one_step(ys, ys_mask, None)

        for i, hi in enumerate(h):
            ytu = torch.log_softmax(self.joint(hi, y[0]), dim=0)
            logp, pred = torch.max(ytu, dim=0)

            if pred != self.blank:
                hyp["yseq"].append(int(pred))
                hyp["score"] += float(logp)

                ys = to_device(self, torch.tensor(hyp["yseq"]).unsqueeze(0))
                ys_mask = to_device(
                    self, subsequent_mask(len(hyp["yseq"])).unsqueeze(0)
                )

                y, c = self.forward_one_step(ys, ys_mask, c)

        return [hyp] 
Example 24
def forward_one_step(self, tgt, tgt_mask, memory, cache=None):
        """Forward one step.

        :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
        :param torch.Tensor tgt_mask: input token mask,  (batch, maxlen_out)
                                      dtype=torch.uint8 in PyTorch 1.2-
                                      dtype=torch.bool in PyTorch 1.2+ (including 1.2)
        :param torch.Tensor memory: encoded memory, float32  (batch, maxlen_in, feat)
        :param List[torch.Tensor] cache:
            cached output list of (batch, max_time_out-1, size)
        :return y, cache: NN output value and cache per `self.decoders`.
            `y.shape` is (batch, maxlen_out, token)
        :rtype: Tuple[torch.Tensor, List[torch.Tensor]]
        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(x)

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        return y, new_cache

    # beam search API (see ScorerInterface) 
Example 25
def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask,  (batch, maxlen_out)
                      dtype=torch.uint8 in PyTorch 1.2-
                      dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            memory: encoded memory, float32  (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
            `y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            cache = self.init_state()
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(x)

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        return y, new_cache

    # beam search API (see ScorerInterface) 
Example 26
def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = torch.relu(x)

        x = self.layer_2(x)
        x = torch.relu(x)

        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)

        return x 
Example 27
def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # batch 0 will always predict EOS. The other batches will predict
        # non-eos scores.
        for batch_sz in [1, 3]:
            n_words = 100
            _non_eos_idxs = [47]
            valid_score_dist = torch.log_softmax(torch.tensor(
                [6., 5.]), dim=0)
            min_length = 5
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            samp = RandomSampling(
                0, 1, 2, batch_sz, torch.device("cpu"), min_length,
                False, set(), False, 30, 1., 1, lengths)
            all_attns = []
            for i in range(min_length + 4):
                word_probs = torch.full(
                    (batch_sz, n_words), -float('inf'))
                # "best" prediction is eos - that should be blocked
                word_probs[0, eos_idx] = valid_score_dist[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs[0]] = valid_score_dist[1]
                word_probs[1:, _non_eos_idxs[0] + i] = 0

                attns = torch.randn(1, batch_sz, 53)
                all_attns.append(attns)
                samp.advance(word_probs, attns)
                if i < min_length:
                    self.assertTrue(
                        samp.topk_scores[0].allclose(valid_score_dist[1]))
                    self.assertTrue(
                        samp.topk_scores[1:].eq(0).all())
                elif i == min_length:
                    # now batch 0 has ended and no others have
                    self.assertTrue(samp.is_finished[0, :].eq(1).all())
                    self.assertTrue(samp.is_finished[1:, 1:].eq(0).all())
                else:  # i > min_length
                    break 
Example 28
def first_step(self, beam, expected_beam_scores, expected_len_pen):
        # no EOS's yet
        assert beam.is_finished.sum() == 0
        scores_1 = torch.log_softmax(torch.tensor(
            [[0, 0,  0, .3,   0, .51, .2, 0],
             [0, 0, 1.5,  0,   0,   0,  0, 0],
             [0, 0,  0,  0, .49, .48,  0, 0],
             [0, 0, 0, .2, .2, .2, .2, .2],
             [0, 0, 0, .2, .2, .2, .2, .2]]
        ), dim=1)
        scores_1 = scores_1.repeat(self.BATCH_SZ, 1)

        beam.advance(deepcopy(scores_1), self.random_attn())

        new_scores = scores_1 + expected_beam_scores.view(-1).unsqueeze(1)
        expected_beam_scores, unreduced_preds = new_scores\
            .view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS)\
            .topk(self.BEAM_SZ, -1)
        expected_bptr_1 = unreduced_preds // self.N_WORDS
        # [5, 3, 2, 6, 0], so beam 2 predicts EOS!
        expected_preds_1 = unreduced_preds - expected_bptr_1 * self.N_WORDS
        self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
        self.assertTrue(beam.topk_scores.allclose(
            expected_beam_scores / expected_len_pen))
        self.assertTrue(beam.topk_ids.equal(expected_preds_1))
        self.assertTrue(beam.current_backptr.equal(expected_bptr_1))
        self.assertEqual(beam.is_finished.sum(), self.BATCH_SZ)
        self.assertTrue(beam.is_finished[:, 2].all())  # beam 2 finished
        beam.update_finished()
        self.assertFalse(beam.top_beam_finished.any())
        self.assertFalse(beam.done)
        return expected_beam_scores