Python源码示例:torch.trace()

示例1
def get_loss(pred, y, criterion, mtr, a=0.5):
    """
    Combine the criterion loss with a penalty computed from the beta matrix.

    For every batch item the product ``mtr @ mtr^T`` is formed, the identity
    is subtracted, and the squared diagonal entries of that residual are
    summed. The batch mean of those sums, scaled by ``a``, is added to
    ``criterion(pred, y)``.

    The penalty is built from differentiable tensor operations on the same
    device and dtype as ``mtr``, so it contributes gradients to ``mtr``
    instead of acting as a detached constant.

    :param pred: predicted value
    :param y: actual value
    :param criterion: nn.CrossEntropyLoss
    :param mtr: beta matrix, indexed as (batch, rows, cols)
    :param a: weight applied to the penalty term
    :return: ``criterion(pred, y)`` plus the one-element penalty tensor
    """
    gram = torch.bmm(mtr, torch.transpose(mtr, 1, 2))
    eye = torch.eye(mtr.size(1), dtype=gram.dtype, device=gram.device)
    # trace(R * R) with an element-wise product only sees the diagonal of R.
    # NOTE(review): a full squared Frobenius norm would sum every entry of
    # the residual instead; the diagonal-only value is kept here so existing
    # loss magnitudes do not change. Confirm which one is intended.
    diag_residual = torch.diagonal(gram - eye, dim1=1, dim2=2)
    penalty = diag_residual.pow(2).sum(dim=1).mean()
    # reshape(1) keeps the one-element result shape of the previous version.
    return torch.add(criterion(pred, y), (penalty * a).reshape(1))
示例2
def feature_smoothing(self, adj, X):
        """Return trace(X^T L X) for a scaled Laplacian L built from ``adj``.

        ``adj`` is first symmetrised as (adj^T + adj) / 2. With ``d`` the row
        sums of that matrix, L is ``S (diag(d) - adj_sym) S`` where
        ``S = diag((d + 1e-3) ** -0.5)`` and any infinite entry of the
        scaling vector is replaced by zero.

        :param adj: square 2-D adjacency tensor
        :param X: 2-D feature tensor with one row per row of ``adj``
        :return: 0-d tensor holding the trace
        """
        sym_adj = (adj.t() + adj)/2
        degree = sym_adj.sum(1).flatten()
        laplacian = torch.diag(degree) - sym_adj

        # Small offset added before the inverse square root.
        inv_sqrt_degree = (degree + 1e-3).pow(-1/2).flatten()
        inv_sqrt_degree[torch.isinf(inv_sqrt_degree)] = 0.
        scaling = torch.diag(inv_sqrt_degree)
        scaled_laplacian = scaling @ laplacian @ scaling

        quadratic_form = torch.matmul(torch.matmul(X.t(), scaled_laplacian), X)
        return torch.trace(quadratic_form)
示例3
def batch_mat2angle(R):
    r""" Calculate the rotation angle for each matrix in a batch

        Ethan Eade's lie group note:
        http://ethaneade.com/lie.pdf equation (17)

        function tested in 'test_geometry.py'

        The angle is acos((trace(R) - 1) / 2); the argument is clamped to
        [-1, 1] before acos is applied.

    :input
    :param Rotation matrix Bx3x3 \in \SO3 space
    --------
    :return
    :param the axis angle B
    """
    # Batched trace in one call; this also yields an empty result for B == 0
    # instead of failing while stacking an empty list.
    R_trace = torch.diagonal(R, dim1=-2, dim2=-1).sum(-1)
    # @todo: not sure whether the clamp is absolutely necessary in training.
    angle = acos(((R_trace - 1) / 2).clamp(-1, 1))
    return angle
示例4
def evaluate(X_data, name='Eval'):
    """Compute and print the average loss of the global ``model`` on ``X_data``.

    Each item of ``X_data`` is split into inputs (all rows but the last) and
    targets (all rows but the first). The per-item loss is the negated trace
    of ``y @ log(out)^T + (1 - y) @ log(1 - out)^T``; the sum over items is
    divided by the total number of output rows.

    Relies on the module-level names ``model``, ``args``, ``np`` and
    ``Variable``.

    :param X_data: indexable collection of tensors
    :param name: label used as the prefix of the printed line
    :return: the averaged loss as a Python float
    """
    model.eval()
    loss_sum = 0.0
    n_rows = 0
    with torch.no_grad():
        for item_idx in np.arange(len(X_data), dtype="int32"):
            sequence = X_data[item_idx]
            x = Variable(sequence[:-1])
            y = Variable(sequence[1:])
            if args.cuda:
                x = x.cuda()
                y = y.cuda()
            output = model(x.unsqueeze(0)).squeeze(0)
            log_on = torch.log(output).float().t()
            log_off = torch.log(1-output).float().t()
            loss = -torch.trace(torch.matmul(y, log_on) +
                                torch.matmul((1-y), log_off))
            loss_sum += loss.item()
            n_rows += output.size(0)
        eval_loss = loss_sum / n_rows
        print(name + " loss: {:.5f}".format(eval_loss))
        return eval_loss
示例5
def btrace(X):
    """Batched trace: [..., N, N] -> [...].

    Sums the main diagonal of the trailing two dimensions in a single
    vectorised call. Works for any number of leading batch dimensions
    (including none) and for non-contiguous inputs. The result has the same
    dtype and device as ``X``.

    :param X: tensor whose last two dimensions are square
    :return: tensor of traces shaped like the leading dimensions of ``X``
    :raises ValueError: if ``X`` has fewer than two dimensions or the last
        two dimensions differ in size
    """
    if X.dim() < 2 or X.size(-1) != X.size(-2):
        raise ValueError(
            "btrace expects a tensor of shape [..., N, N], got {}".format(tuple(X.shape)))
    # dtype is pinned so integer inputs are not promoted by the reduction.
    return torch.diagonal(X, dim1=-2, dim2=-1).sum(-1, dtype=X.dtype)
示例6
def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
  """Frechet distance between two Gaussians, computed with torch ops.

  Taken from https://github.com/bioinf-jku/TTUR
  For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) the distance is
          d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
  The matrix square root is obtained from 50 iterations of the
  ``sqrt_newton_schulz`` helper defined elsewhere.

  Params:
  -- mu1   : Mean vector for the first distribution (generated samples).
  -- mu2   : Mean vector for the second distribution (reference data).
  -- sigma1: Covariance matrix for the first distribution.
  -- sigma2: Covariance matrix for the second distribution.
  -- eps   : Accepted for interface compatibility; not used here.
  Returns:
  --   : The Frechet Distance.
  """
  assert mu1.shape == mu2.shape, \
    'Training and test mean vectors have different lengths'
  assert sigma1.shape == sigma2.shape, \
    'Training and test covariances have different dimensions'

  mean_gap = mu1 - mu2
  # The helper is given a leading batch dimension of one, dropped afterwards.
  cov_product = sigma1.mm(sigma2).unsqueeze(0)
  sqrt_cov_product = sqrt_newton_schulz(cov_product, 50).squeeze()
  squared_mean_dist = mean_gap.dot(mean_gap)
  return (squared_mean_dist + torch.trace(sigma1) + torch.trace(sigma2)
          - 2 * torch.trace(sqrt_cov_product))


# Calculate Inception Score mean + std given softmax'd logits and number of splits 
示例7
def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
  """Frechet distance between two Gaussians, computed with torch ops.

  Taken from https://github.com/bioinf-jku/TTUR
  For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) the distance is
          d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
  The matrix square root is obtained from 50 iterations of the
  ``sqrt_newton_schulz`` helper defined elsewhere.

  Params:
  -- mu1   : Mean vector for the first distribution (generated samples).
  -- mu2   : Mean vector for the second distribution (reference data).
  -- sigma1: Covariance matrix for the first distribution.
  -- sigma2: Covariance matrix for the second distribution.
  -- eps   : Accepted for interface compatibility; not used here.
  Returns:
  --   : The Frechet Distance.
  """
  assert mu1.shape == mu2.shape, \
    'Training and test mean vectors have different lengths'
  assert sigma1.shape == sigma2.shape, \
    'Training and test covariances have different dimensions'

  delta_mu = mu1 - mu2
  # The helper is given a leading batch dimension of one, dropped afterwards.
  batched_product = sigma1.mm(sigma2).unsqueeze(0)
  root = sqrt_newton_schulz(batched_product, 50).squeeze()
  mean_term = delta_mu.dot(delta_mu)
  return (mean_term + torch.trace(sigma1) + torch.trace(sigma2)
          - 2 * torch.trace(root))


# Calculate Inception Score mean + std given softmax'd logits and number of splits 
示例8
def batch_mat2twist(R):
    r""" The log map from SO3 to so3

        Calculate the twist vector from Rotation matrix

        Ethan Eade's lie group note:
        http://ethaneade.com/lie.pdf equation (18)

        @todo: may rename the interface to batch_so3logmap(R)

        function tested in 'test_geometry.py'

        @note: for angles numerically equal to zero the factor
        theta / (2 sin(theta)) is replaced by its series 1/2 + theta^2/12,
        so identity rotations map to the zero twist instead of NaN.
        Angles close to pi are not treated specially, and gradients at an
        exactly-zero angle may still be non-finite because of acos.

    :input
    :param Rotation matrix Bx3x3 \in \SO3 space
    --------
    :param the twist vector Bx3 \in \so3 space
    """
    B = R.size()[0]

    # Batched trace without a Python loop (also valid for B == 0).
    tr = torch.diagonal(R, dim1=-2, dim2=-1).sum(-1)
    theta = acos(((tr - 1) / 2).clamp(-1, 1))

    # reshape (not view) so non-contiguous inputs are accepted.
    r11, r12, r13, r21, r22, r23, r31, r32, r33 = torch.split(R.reshape(B, -1), 1, dim=1)
    res = torch.cat([r32 - r23, r13 - r31, r21 - r12], dim=1)

    near_zero = theta < 1e-6
    # Substitute a harmless angle where theta ~ 0 so the discarded branch of
    # torch.where never evaluates 0 / 0.
    safe_theta = torch.where(near_zero, torch.ones_like(theta), theta)
    magnitude = torch.where(
        near_zero,
        0.5 + theta * theta / 12,
        0.5 * safe_theta / sin(safe_theta))

    return magnitude.view(B, 1) * res
示例9
def forward(self):
        """Return trace(G @ inverse(G)) - logdet(G) for G = W^T W.

        ``W`` is the tensor returned by ``self.select_param()`` with
        dimension 2 squeezed twice.
        """
        selected = self.select_param()
        weights_2d = selected.squeeze(dim=2).squeeze(dim=2)
        gram = torch.mm(torch.t(weights_2d), weights_2d)
        gram_inverse = torch.inverse(gram)
        trace_term = torch.trace(torch.mm(gram, gram_inverse))
        return trace_term - torch.logdet(gram)
示例10
def trace(self, a):
        """Return ``torch.trace(a)``, the sum of the main diagonal of ``a``."""
        diagonal_sum = torch.trace(a)
        return diagonal_sum
示例11
def block_trace(self, X, m, n):
        """Return the n x n grid of traces of the m x m blocks of ``X``.

        Block (i, j) is ``X[..., i*m:(i+1)*m, j*m:(j+1)*m]``. Each block is
        reduced with ``self.trace``; every row of results is combined with
        ``self.pack`` and the rows are combined with ``self.pack`` again.
        """
        trace_rows = [
            [
                self.trace(X[..., row*m:(row+1)*m, col*m:(col+1)*m])
                for col in range(n)
            ]
            for row in range(n)
        ]
        packed_rows = [self.pack(list(trace_row)) for trace_row in trace_rows]
        return self.pack(packed_rows)
示例12
def normalize_factors(A, B):
    """Rescale ``A`` and ``B`` by reciprocal factors derived from their traces.

    With tA = trace(A) + 1e-10 and tB = trace(B) + 1e-10, returns
    ``(A * sqrt(tB / tA), B * sqrt(tA / tB))``.

    :raises AssertionError: if either offset trace is not strictly positive
    """
    eps = 1e-10

    trace_a, trace_b = (torch.trace(factor) + eps for factor in (A, B))
    assert trace_a > 0, 'Must PD. A not PD'
    assert trace_b > 0, 'Must PD. B not PD'

    scale_for_a = (trace_b / trace_a) ** 0.5
    scale_for_b = (trace_a / trace_b) ** 0.5
    return A * scale_for_a, B * scale_for_b
示例13
def trace(tensor: BKTensor) -> BKTensor:
    """Stack the traces of ``tensor[0]`` and ``tensor[1]`` into one tensor.

    The parameter names in the original suggest index 0 holds the real part
    and index 1 the imaginary part — confirm against the BKTensor definition.
    """
    first_part, second_part = tensor[0], tensor[1]
    traces = (torch.trace(first_part), torch.trace(second_part))
    return torch.stack(traces)
示例14
def btrace(X):
    """Batched trace: [..., N, N] -> [...].

    Sums the main diagonal of the trailing two dimensions in a single
    vectorised call. Works for any number of leading batch dimensions
    (including none) and for non-contiguous inputs. The result has the same
    dtype and device as ``X``.

    :param X: tensor whose last two dimensions are square
    :return: tensor of traces shaped like the leading dimensions of ``X``
    :raises ValueError: if ``X`` has fewer than two dimensions or the last
        two dimensions differ in size
    """
    if X.dim() < 2 or X.size(-1) != X.size(-2):
        raise ValueError(
            "btrace expects a tensor of shape [..., N, N], got {}".format(tuple(X.shape)))
    # dtype is pinned so integer inputs are not promoted by the reduction.
    return torch.diagonal(X, dim1=-2, dim2=-1).sum(-1, dtype=X.dtype)
示例15
def forward(self, ypred, ytrue):
		"""Mean geodesic angle between predicted and ground-truth rotations.

		For each index ``k`` of ``ytrue`` the angle is
		acos((trace(ypred[k]^T @ ytrue[k]) - 1) / 2), with the acos argument
		clamped to [-1 + eps, 1 - eps]. ``eps`` is read from the enclosing
		module scope.
		"""
		pair_traces = []
		for k in range(ytrue.size(0)):
			relative = torch.mm(ypred[k].t(), ytrue[k])
			pair_traces.append(torch.trace(relative))
		cos_angle = (torch.stack(pair_traces) - 1.0) / 2
		geodesic = torch.acos(torch.clamp(cos_angle, -1 + eps, 1 - eps))
		return torch.mean(geodesic)
示例16
def my_loss(self, ypred, ytrue):
		"""Mean geodesic angle between predicted and ground-truth rotations.

		For each index ``k`` of ``ytrue`` the angle is
		acos((trace(ypred[k]^T @ ytrue[k]) - 1) / 2), with the acos argument
		clamped to [-1 + eps, 1 - eps]. ``eps`` is read from the enclosing
		module scope.
		"""
		per_item_trace = []
		for k in range(ytrue.size(0)):
			product = torch.mm(ypred[k].t(), ytrue[k])
			per_item_trace.append(torch.trace(product))
		cosine = (torch.stack(per_item_trace) - 1.0) / 2
		angles = torch.acos(torch.clamp(cosine, -1 + eps, 1 - eps))
		return torch.mean(angles)
示例17
def forward(self, ypred, ytrue):
		"""Mean geodesic angle between predicted and ground-truth rotations.

		For each index ``k`` of ``ytrue`` the angle is
		acos((trace(ypred[k]^T @ ytrue[k]) - 1) / 2), with the acos argument
		clamped to [-1 + eps, 1 - eps]. ``eps`` is read from the enclosing
		module scope.
		"""
		n_items = ytrue.size(0)
		trace_list = []
		for k in range(n_items):
			rel_rotation = torch.mm(ypred[k].t(), ytrue[k])
			trace_list.append(torch.trace(rel_rotation))
		cos_theta = (torch.stack(trace_list) - 1.0) / 2
		theta = torch.acos(torch.clamp(cos_theta, -1 + eps, 1 - eps))
		return torch.mean(theta)
示例18
def forward(self, adj_mat, l_mat):
        """Encode/decode ``adj_mat`` and return the loss terms.

        The input passes through ``encode0``, ``encode1``, ``decode0`` and
        ``decode1``, each followed by a leaky ReLU. The output of the second
        encoder stage is stored on ``self.embedding``.

        :param adj_mat: input matrix fed to the encoder
        :param l_mat: matrix used in the trace term with the embedding
        :return: tuple ``(alpha * L_1st, L_2nd, alpha * L_1st + L_2nd, L_reg)``
            where ``L_1st = 2 * trace(E^T @ l_mat @ E)`` for the embedding E,
            ``L_2nd`` is the sum of squares of
            ``(adj_mat - reconstruction) * adj_mat * beta``, and ``L_reg``
            adds ``nu1 * sum|p| + nu2 * sum(p^2)`` over all parameters.
        """
        hidden = F.leaky_relu(self.encode0(adj_mat))
        hidden = F.leaky_relu(self.encode1(hidden))
        self.embedding = hidden
        reconstruction = F.leaky_relu(self.decode0(hidden))
        reconstruction = F.leaky_relu(self.decode1(reconstruction))

        embedding_t = torch.t(self.embedding)
        L_1st = 2 * torch.trace(torch.mm(torch.mm(embedding_t, l_mat), self.embedding))
        weighted_error = (adj_mat - reconstruction) * adj_mat * self.beta
        L_2nd = torch.sum(weighted_error * weighted_error)

        L_reg = 0
        for param in self.parameters():
            l1_term = self.nu1 * torch.sum(torch.abs(param))
            l2_term = self.nu2 * torch.sum(param * param)
            L_reg += l1_term + l2_term

        scaled_first = self.alpha * L_1st
        return scaled_first, L_2nd, scaled_first + L_2nd, L_reg
示例19
def _inv_covs(self, xxt, ggt, num_locations):
        """Return the inverses of the two regularised covariance factors.

        When ``self.pi`` is truthy, ``pi`` is
        ``(trace(xxt) * ggt.shape[0]) / (trace(ggt) * xxt.shape[0])``;
        otherwise it is 1.0. With ``eps = self.eps / num_locations``, the
        value ``sqrt(eps * pi)`` is added to the diagonal of ``xxt`` and
        ``sqrt(eps / pi)`` to the diagonal of ``ggt`` before inversion.

        :return: tuple ``(inverse of regularised xxt, inverse of regularised ggt)``
        """
        if self.pi:
            scaled_trace_x = torch.trace(xxt) * ggt.shape[0]
            scaled_trace_g = torch.trace(ggt) * xxt.shape[0]
            pi = (scaled_trace_x / scaled_trace_g)
        else:
            pi = 1.0

        damping = self.eps / num_locations
        ridge_x = xxt.new(xxt.shape[0]).fill_((damping * pi) ** 0.5)
        ridge_g = ggt.new(ggt.shape[0]).fill_((damping / pi) ** 0.5)

        inv_xxt = (xxt + torch.diag(ridge_x)).inverse()
        inv_ggt = (ggt + torch.diag(ridge_g)).inverse()
        return inv_xxt, inv_ggt
示例20
def test_trace():
    """Check ``utils.trace`` against ``torch.trace`` for one and two matrices."""
    single = torch.arange(1, 10).view(3, 3)
    assert utils.trace(single)[0] == torch.trace(single)

    first = torch.arange(1, 10).view(1, 3, 3)
    second = torch.arange(11, 20).view(1, 3, 3)
    batch = torch.cat([first, second], dim=0)
    traces = utils.trace(batch)
    assert len(traces) == 2
    assert traces[0] == torch.trace(batch[0])
    assert traces[1] == torch.trace(batch[1])
示例21
def get_semi_orth_weight(conv1dlayer):
        """Return an updated conv weight that is closer to semi-orthogonal.

        Based off ConstrainOrthonormalInternal in nnet-utils.cc in Kaldi
        src/nnet3, including the tweaks that slow the update; only the
        'floating scale' case is implemented. Runs under ``torch.no_grad()``.

        :param conv1dlayer: object whose ``weight`` is indexed as
            (out_filters, in_filters, kernel_width)
        :return: tensor with the same shape as ``conv1dlayer.weight``
        :raises AssertionError: if the computed ratio is not above 0.99
        """
        with torch.no_grad():
            weight_shape = conv1dlayer.weight.shape
            # Flatten (in_filters, kernel_width) into one axis and transpose,
            # giving M with shape (in_dim [rows], out_dim [cols]).
            M = conv1dlayer.weight.reshape(
                weight_shape[0], weight_shape[1]*weight_shape[2]).T
            n_rows, n_cols = M.shape[0], M.shape[1]
            more_rows_than_cols = n_rows > n_cols
            if more_rows_than_cols:
                # The constraint is applied to the transpose in this case.
                M = M.T

            P = torch.mm(M, M.T)
            trace_P = torch.trace(P)
            trace_PP = torch.trace(torch.mm(P, P.T))
            ratio = trace_PP * P.shape[0] / (trace_P * trace_P)

            # Tweak to avoid divergence (more info in Kaldi): the further the
            # ratio is above 1, the smaller the step.
            assert ratio > 0.99
            update_speed = 0.125
            if ratio > 1.02:
                update_speed *= 0.5
                if ratio > 1.1:
                    update_speed *= 0.5

            scale2 = trace_PP/trace_P
            # matrix_power(P, 0) is the identity with P's shape.
            identity = torch.matrix_power(P, 0)
            deviation = P - (identity * scale2)
            alpha = update_speed / scale2
            step = (-4.0 * alpha) * torch.mm(deviation, M)
            updated = M + step

            # `updated` is (cols, rows) when rows > cols, else (rows, cols);
            # bring it to (cols, rows) before restoring the conv shape.
            if more_rows_than_cols:
                return updated.reshape(*weight_shape)
            return updated.T.reshape(*weight_shape)
示例22
def get_semi_orth_error(conv1dlayer):
        """Return the Frobenius norm of ``P - scale2 * I`` for a conv weight.

        ``P = M @ M^T`` where ``M`` is the weight flattened to 2-D (and
        transposed so it has no more rows than columns), and
        ``scale2 = sqrt(trace(P @ P^T) / trace(P)) ** 2``.
        Runs under ``torch.no_grad()``.

        :param conv1dlayer: object whose ``weight`` is a 3-D tensor
        :return: 0-d tensor holding the norm
        """
        with torch.no_grad():
            weight_shape = conv1dlayer.weight.shape
            M = conv1dlayer.weight.reshape(
                weight_shape[0], weight_shape[1]*weight_shape[2]).T
            n_rows, n_cols = M.shape[0], M.shape[1]
            if n_rows > n_cols:  # semi orthogonal constraint for rows > cols
                M = M.T

            P = torch.mm(M, M.T)
            trace_P = torch.trace(P)
            trace_PP = torch.trace(torch.mm(P, P.T))
            scale2 = torch.sqrt(trace_PP/trace_P) ** 2
            # matrix_power(P, 0) is the identity with P's shape.
            deviation = P - (torch.matrix_power(P, 0) * scale2)
            return torch.norm(deviation, p='fro')
示例23
def train(ep):
    """Run one training pass over ``X_train`` in shuffled order.

    Each item is split into inputs (all rows but the last) and targets (all
    rows but the first). The loss is the negated trace of
    ``y @ log(out)^T + (1 - y) @ log(1 - out)^T``. A running average is
    printed whenever the (shuffled) item index is a positive multiple of
    ``args.log_interval``, after which the running totals are reset.

    Relies on the module-level names ``model``, ``optimizer``, ``X_train``,
    ``args``, ``lr``, ``np`` and ``Variable``.

    :param ep: epoch number, used only in the log line
    """
    model.train()
    total_loss = 0
    count = 0
    train_idx_list = np.arange(len(X_train), dtype="int32")
    np.random.shuffle(train_idx_list)
    for idx in train_idx_list:
        data_line = X_train[idx]
        x, y = Variable(data_line[:-1]), Variable(data_line[1:])
        if args.cuda:
            x, y = x.cuda(), y.cuda()

        optimizer.zero_grad()
        output = model(x.unsqueeze(0)).squeeze(0)
        loss = -torch.trace(torch.matmul(y, torch.log(output).float().t()) +
                            torch.matmul((1 - y), torch.log(1 - output).float().t()))
        total_loss += loss.item()
        count += output.size(0)

        loss.backward()
        # Clip only after backward(): before it the gradients have just been
        # cleared by zero_grad(), so clipping there has no effect on the step.
        if args.clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()
        if idx > 0 and idx % args.log_interval == 0:
            cur_loss = total_loss / count
            print("Epoch {:2d} | lr {:.5f} | loss {:.5f}".format(ep, lr, cur_loss))
            total_loss = 0.0
            count = 0
示例24
def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
  """Frechet distance between two Gaussians, computed with NumPy/SciPy.

  Taken from https://github.com/bioinf-jku/TTUR
  For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) the distance is
          d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
  Stable version by Dougal J. Sutherland.

  Params:
  -- mu1   : Mean of the activations for generated samples.
  -- mu2   : Mean of the activations, precalculated on a representative
             data set.
  -- sigma1: Covariance matrix of the activations for generated samples.
  -- sigma2: Covariance matrix of the activations, precalculated on a
             representative data set.
  -- eps   : Value added to the diagonal of both covariances when the
             matrix square root of their product is not finite.
  Returns:
  --   : The Frechet Distance.
  Raises:
  --   : ValueError if the square root has a diagonal imaginary part that
         is not within 1e-3 of zero.
  """
  mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
  sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

  assert mu1.shape == mu2.shape, \
    'Training and test mean vectors have different lengths'
  assert sigma1.shape == sigma2.shape, \
    'Training and test covariances have different dimensions'

  mean_gap = mu1 - mu2

  # The product might be almost singular; retry with a small diagonal offset.
  covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
  if not np.isfinite(covmean).all():
    msg = ('fid calculation produces singular product; '
           'adding %s to diagonal of cov estimates') % eps
    print(msg)
    jitter = np.eye(sigma1.shape[0]) * eps
    covmean = linalg.sqrtm((sigma1 + jitter).dot(sigma2 + jitter))

  # Numerical error might leave a slight imaginary component.
  if np.iscomplexobj(covmean):
    print('wat')
    diagonal_imag = np.diagonal(covmean).imag
    if not np.allclose(diagonal_imag, 0, atol=1e-3):
      m = np.max(np.abs(covmean.imag))
      raise ValueError('Imaginary component {}'.format(m))
    covmean = covmean.real

  sqrt_trace = np.trace(covmean)
  return mean_gap.dot(mean_gap) + np.trace(sigma1) + np.trace(sigma2) - 2 * sqrt_trace
示例25
def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
  """Frechet distance between two Gaussians, computed with NumPy/SciPy.

  Taken from https://github.com/bioinf-jku/TTUR
  For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) the distance is
          d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
  Stable version by Dougal J. Sutherland.

  Params:
  -- mu1   : Mean of the activations for generated samples.
  -- mu2   : Mean of the activations, precalculated on a representative
             data set.
  -- sigma1: Covariance matrix of the activations for generated samples.
  -- sigma2: Covariance matrix of the activations, precalculated on a
             representative data set.
  -- eps   : Value added to the diagonal of both covariances when the
             matrix square root of their product is not finite.
  Returns:
  --   : The Frechet Distance.
  Raises:
  --   : ValueError if the square root has a diagonal imaginary part that
         is not within 1e-3 of zero.
  """
  mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
  sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

  assert mu1.shape == mu2.shape, \
    'Training and test mean vectors have different lengths'
  assert sigma1.shape == sigma2.shape, \
    'Training and test covariances have different dimensions'

  delta_mu = mu1 - mu2

  # The product might be almost singular; retry with a small diagonal offset.
  covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
  if not np.isfinite(covmean).all():
    msg = ('fid calculation produces singular product; '
           'adding %s to diagonal of cov estimates') % eps
    print(msg)
    diag_offset = np.eye(sigma1.shape[0]) * eps
    covmean = linalg.sqrtm((sigma1 + diag_offset).dot(sigma2 + diag_offset))

  # Numerical error might leave a slight imaginary component.
  if np.iscomplexobj(covmean):
    print('wat')
    imag_on_diagonal = np.diagonal(covmean).imag
    if not np.allclose(imag_on_diagonal, 0, atol=1e-3):
      m = np.max(np.abs(covmean.imag))
      raise ValueError('Imaginary component {}'.format(m))
    covmean = covmean.real

  root_trace = np.trace(covmean)
  return delta_mu.dot(delta_mu) + np.trace(sigma1) + np.trace(sigma2) - 2 * root_trace
示例26
def BilateralRefSmoothnessLoss(self, pred_R, targets, att, num_features):
        """Average, over the images in ``pred_R``, of ``(p'p - p'Ap) / Z``.

        For each image ``p`` is the prediction flattened to one row per
        pixel, ``p'p`` is the sum of its squared entries and ``p'Ap`` is
        accumulated from the splat/blur pipeline below. ``Z`` is
        ``pred_R.size(1) * pred_R.size(2) * pred_R.size(3)``.

        :param pred_R: 4-D prediction tensor, indexed as [image, :, :, :]
        :param targets: mapping read at the keys ``att + 'B_list'``,
            ``att + 'S'`` and ``att + 'N'``; each value is indexed per image
        :param att: key prefix used to look entries up in ``targets``
        :param num_features: the blur loop runs ``num_features + 1`` times
        :return: one-element CUDA tensor holding the averaged loss

        NOTE(review): ``Sparse`` is defined elsewhere; it is presumably a
        sparse-times-dense product — confirm against its definition.
        """
        # pred_R = pred_R.cpu()
        # The accumulator is allocated directly on the GPU and zeroed by hand.
        total_loss = Variable(torch.cuda.FloatTensor(1))
        total_loss[0] = 0
        N = pred_R.size(2) * pred_R.size(3)
        Z = (pred_R.size(1) * N )

        # grad_input = torch.FloatTensor(pred_R.size())
        # grad_input = grad_input.zero_()

        for i in range(pred_R.size(0)): # for each image
            B_mat = targets[att+'B_list'][i] # still a list of blur sparse matrices
            S_mat = Variable(targets[att + 'S'][i].cuda(), requires_grad = False) # Splat and Slicing matrix
            n_vec = Variable(targets[att + 'N'][i].cuda(), requires_grad = False) # bi-stochastic vector, i.e. a diagonal matrix

            # Flatten the trailing two dims and transpose: one row per pixel.
            p = pred_R[i,:,:,:].view(pred_R.size(1),-1).t() # NX3
            # p'p
            # p_norm = torch.mm(p.t(), p) 
            # p_norm_sum = torch.trace(p_norm)
            # Sum of squared entries, computed without forming p'p.
            p_norm_sum = torch.sum(torch.mul(p,p))

            # S * N * p
            Snp = torch.mul(n_vec.repeat(1,pred_R.size(1)), p)
            sp_mm = Sparse()
            Snp = sp_mm(Snp, S_mat)

            Snp_1 = Snp.clone()
            Snp_2 = Snp.clone()

            # blur: Snp_1 is passed through B_mat[0..num_features] in
            # ascending order, Snp_2 through the same list in descending order.
            for f in range(num_features+1):
                B_var1 = Variable(B_mat[f].cuda(), requires_grad = False)
                sp_mm1 = Sparse()
                Snp_1 = sp_mm1(Snp_1, B_var1)
                
                B_var2 = Variable(B_mat[num_features-f].cuda(), requires_grad = False)               
                sp_mm2 = Sparse()
                Snp_2 = sp_mm2(Snp_2, B_var2)

            # The two blur orders are summed before the inner product with Snp.
            Snp_12 = Snp_1 + Snp_2
            pAp = torch.sum(torch.mul(Snp, Snp_12))


            total_loss = total_loss + ((p_norm_sum - pAp)/Z)

        total_loss = total_loss/pred_R.size(0) 
        # average over all images
        return total_loss 
示例27
def lw_shrink(X_t):
    """
    Estimates the shrunk Ledoit-Wolf covariance matrix

    All intermediate tensors are created on the device and with the dtype of
    ``X_t``, so the function works for CPU inputs on CUDA-capable machines
    and for non-float32 inputs.

    Args:
      X_t: samples x variants, or batch x samples x variants

    Returns:
      shrunk_cov_t: shrunk covariance
      shrinkage_t:  shrinkage coefficient (one per batch item for 3-D input)

    Adapted from scikit-learn:
    https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/covariance/shrunk_covariance_.py
    """
    if len(X_t.shape) == 2:
        n_samples, n_features = X_t.shape  # samples x variants
        X_t = X_t - X_t.mean(0)
        X2_t = X_t.pow(2)
        emp_cov_trace_sum = X2_t.sum() / n_samples
        # X^T X is needed for both delta_ and the empirical covariance.
        gram_t = torch.mm(X_t.t(), X_t)
        delta_ = gram_t.pow(2).sum() / n_samples**2
        beta_ = torch.mm(X2_t.t(), X2_t).sum()
        beta = 1. / (n_features * n_samples) * (beta_ / n_samples - delta_)
        delta = delta_ - 1. * emp_cov_trace_sum**2 / n_features
        delta /= n_features
        beta = torch.min(beta, delta)
        shrinkage_t = 0 if beta == 0 else beta / delta
        emp_cov_t = gram_t / n_samples
        mu_t = torch.trace(emp_cov_t) / n_features
        shrunk_cov_t = (1. - shrinkage_t) * emp_cov_t
        shrunk_cov_t.view(-1)[::n_features + 1] += shrinkage_t * mu_t  # add to diagonal
    else:  # broadcast along first dimension
        n_samples, n_features = X_t.shape[1:]  # samples x variants
        X_t = X_t - X_t.mean(1, keepdim=True)
        X2_t = X_t.pow(2)
        emp_cov_trace_sum = X2_t.sum([1, 2]) / n_samples
        gram_t = torch.matmul(torch.transpose(X_t, 1, 2), X_t)
        delta_ = gram_t.pow(2).sum([1, 2]) / n_samples**2
        beta_ = torch.matmul(torch.transpose(X2_t, 1, 2), X2_t).sum([1, 2])
        beta = 1. / (n_features * n_samples) * (beta_ / n_samples - delta_)
        delta = delta_ - 1. * emp_cov_trace_sum**2 / n_features
        delta /= n_features
        beta = torch.min(beta, delta)
        # zeros_like keeps the fallback on beta's own device and dtype.
        shrinkage_t = torch.where(beta == 0, torch.zeros_like(beta), beta / delta)
        emp_cov_t = gram_t / n_samples
        mu_t = torch.diagonal(emp_cov_t, dim1=1, dim2=2).sum(1) / n_features
        shrunk_cov_t = (1 - shrinkage_t.reshape([shrinkage_t.shape[0], 1, 1])) * emp_cov_t

        # torch.diagonal returns a view, so the in-place add updates the
        # diagonal of every matrix in the batch.
        torch.diagonal(shrunk_cov_t, dim1=1, dim2=2).add_((shrinkage_t * mu_t).unsqueeze(-1))

    return shrunk_cov_t, shrinkage_t