Python source code examples: torch.diag_embed()
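
These examples are collected from open-source projects and show how torch.diag_embed() is used in practice. torch.diag_embed(input, offset=0, dim1=-2, dim2=-1) builds a tensor whose last two dimensions (by default) form diagonal matrices filled from the last dimension of input, so a batch of vectors of shape (N, D) becomes a batch of diagonal matrices of shape (N, D, D). The recurring pattern below is turning per-dimension scales or variances into batched diagonal covariance matrices. A minimal sketch of that behavior (not taken from any of the projects below):

import torch

scale = torch.tensor([[0.5, 1.0, 2.0],
                      [1.5, 0.1, 0.3]])   # per-dimension stdevs, shape (2, 3)
covars = torch.diag_embed(scale ** 2)     # shape (2, 3, 3): one diagonal covariance per row
print(covars[0])
# tensor([[0.2500, 0.0000, 0.0000],
#         [0.0000, 1.0000, 0.0000],
#         [0.0000, 0.0000, 4.0000]])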

Example 1
def init_action_pd(ActionPD, pdparam):
    '''
    Initialize the action_pd for discrete or continuous actions:
    - discrete: action_pd = ActionPD(logits)
    - continuous: action_pd = ActionPD(loc, scale)
    '''
    if 'logits' in ActionPD.arg_constraints:  # discrete
        action_pd = ActionPD(logits=pdparam)
    else:  # continuous, args = loc and scale
        if isinstance(pdparam, list):  # split output
            loc, scale = pdparam
        else:
            loc, scale = pdparam.transpose(0, 1)
        # scale (stdev) must be > 0; enforce with softplus plus a small positive constant
        scale = F.softplus(scale) + 1e-8
        if isinstance(pdparam, list):  # split output
            # construct covars from a batched scale tensor
            covars = torch.diag_embed(scale)
            action_pd = ActionPD(loc=loc, covariance_matrix=covars)
        else:
            action_pd = ActionPD(loc=loc, scale=scale)
    return action_pd 
Example 2
def __init__(self, nnodes, nfeat, nhid, nclass, gamma=1.0, beta1=5e-4, beta2=5e-4, lr=0.01, dropout=0.6, device='cpu'):
        super(RGCN, self).__init__()

        self.device = device
        # adj_norm = normalize(adj)
        # first turn original features to distribution
        self.lr = lr
        self.gamma = gamma
        self.beta1 = beta1
        self.beta2 = beta2
        self.nclass = nclass
        self.nhid = nhid // 2
        # self.gc1 = GaussianConvolution(nfeat, nhid, dropout=dropout)
        # self.gc2 = GaussianConvolution(nhid, nclass, dropout)
        self.gc1 = GGCL_F(nfeat, nhid, dropout=dropout)
        self.gc2 = GGCL_D(nhid, nclass, dropout=dropout)

        self.dropout = dropout
        # self.gaussian = MultivariateNormal(torch.zeros(self.nclass), torch.eye(self.nclass))
        self.gaussian = MultivariateNormal(torch.zeros(nnodes, self.nclass),
                torch.diag_embed(torch.ones(nnodes, self.nclass)))
        self.adj_norm1, self.adj_norm2 = None, None
        self.features, self.labels = None, None 
Example 3
def from_log_cholesky(cls,
                          log_diag: torch.Tensor,
                          off_diag: torch.Tensor,
                          **kwargs) -> 'Covariance':

        assert log_diag.shape[:-1] == off_diag.shape[:-1]
        batch_dim = log_diag.shape[:-1]

        rank = log_diag.shape[-1]
        L = torch.diag_embed(torch.exp(log_diag))

        idx = 0
        for i in range(rank):
            for j in range(i):
                L[..., i, j] = off_diag[..., idx]
                idx += 1

        out = cls(size=batch_dim + (rank, rank))
        if kwargs:
            out = out.to(**kwargs)
        perm_shape = tuple(range(len(batch_dim))) + (-1, -2)
        out[:] = L.matmul(L.permute(perm_shape))
        return out 
Example 4
def init_action_pd(ActionPD, pdparam):
    '''
    Initialize the action_pd for discrete or continuous actions:
    - discrete: action_pd = ActionPD(logits)
    - continuous: action_pd = ActionPD(loc, scale)
    '''
    args = ActionPD.arg_constraints
    if 'logits' in args:  # discrete
        # for relaxed discrete dist. with reparametrizable discrete actions
        pd_kwargs = {'temperature': torch.tensor(1.0)} if hasattr(ActionPD, 'temperature') else {}
        action_pd = ActionPD(logits=pdparam, **pd_kwargs)
    else:  # continuous, args = loc and scale
        if isinstance(pdparam, list):  # split output
            loc, scale = pdparam
        else:
            loc, scale = pdparam.transpose(0, 1)
        # scale (stdev) must be > 0, log-clamp-exp
        scale = torch.clamp(scale, min=-20, max=2).exp()
        if 'covariance_matrix' in args:  # split output
            # construct covars from a batched scale tensor
            covars = torch.diag_embed(scale)
            action_pd = ActionPD(loc=loc, covariance_matrix=covars)
        else:
            action_pd = ActionPD(loc=loc, scale=scale)
    return action_pd 
Example 5
def __local_curvatures(self, module, g_inp, g_out):
        if self.derivatives.hessian_is_zero():
            return []
        if not self.derivatives.hessian_is_diagonal():
            raise NotImplementedError

        def positive_part(sign, H):
            return clamp(sign * H, min=0)

        def diag_embed_multi_dim(H):
            """Convert [N, C_in, H_in, ...] to [N, C_in * H_in * ...,],
            embed into [N, C_in * H_in * ..., C_in * H_in = V], convert back
            to [V, N, C_in, H_in, ...,  V]."""
            feature_shapes = H.shape[1:]
            V, N = prod(feature_shapes), H.shape[0]

            H_diag = diag_embed(H.view(N, V))
            # [V, N, C_in, H_in, ...]
            shape = (V, N, *feature_shapes)
            return einsum("nic->cni", H_diag).view(shape)

        def decompose_into_positive_and_negative_sqrt(H):
            return [
                [diag_embed_multi_dim(positive_part(sign, H).sqrt_()), sign]
                for sign in [self.PLUS, self.MINUS]
            ]

        H = self.derivatives.hessian_diagonal(module, g_inp, g_out)
        return decompose_into_positive_and_negative_sqrt(H) 
Example 6
def _sqrt_hessian(self, module, g_inp, g_out):
        self._check_2nd_order_parameters(module)

        probs = self._get_probs(module)
        tau = torchsqrt(probs)
        V_dim, C_dim = 0, 2
        Id = diag_embed(ones_like(probs), dim1=V_dim, dim2=C_dim)
        Id_tautau = Id - einsum("nv,nc->vnc", tau, tau)
        sqrt_H = einsum("nc,vnc->vnc", tau, Id_tautau)

        if module.reduction == "mean":
            N = module.input0.shape[0]
            sqrt_H /= sqrt(N)

        return sqrt_H 
Example 7
def get_laplacian_nuc_norm(self, A: 'N x C x S'):

        N, C, _ = A.size()
        # print(A)
        AAT = torch.bmm(A, A.permute(0, 2, 1))
        ones = torch.ones((N, C, 1), device='cuda')
        D = torch.bmm(AAT, ones).view(N, C)
        D = torch.diag_embed(D)

        return nuclear_norm(D - AAT, sym=True).sum() / N 
Example 8
def evaluate(self, state, action):   
        action_mean = self.actor(state)
        
        action_var = self.action_var.expand_as(action_mean)
        cov_mat = torch.diag_embed(action_var).to(device)
        
        dist = MultivariateNormal(action_mean, cov_mat)
        
        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_value = self.critic(state)
        
        return action_logprobs, torch.squeeze(state_value), dist_entropy 
Example 9
def evaluate_lazy_tensor(self, lazy_tensor):
        diag = lazy_tensor._diag_tensor._diag
        tensor = lazy_tensor._lazy_tensor.tensor
        return tensor + torch.diag_embed(diag, dim1=-2, dim2=-1) 
Example 10
def evaluate(self):
        if self._diag.dim() == 0:
            return self._diag
        return torch.diag_embed(self._diag) 
Example 11
def _eval_corr_matrix(self):
        tnc = self.task_noise_corr
        fac_diag = torch.ones(*tnc.shape[:-1], self.num_tasks, device=tnc.device, dtype=tnc.dtype)
        Cfac = torch.diag_embed(fac_diag)
        Cfac[..., self.tidcs[0], self.tidcs[1]] = self.task_noise_corr
        # squared rows must sum to one for this to be a correlation matrix
        C = Cfac / Cfac.pow(2).sum(dim=-1, keepdim=True).sqrt()
        return C @ C.transpose(-1, -2) 
Example 12
def _create_marginal_input(self, batch_shape=torch.Size()):
        mat = torch.randn(*batch_shape, 5, 5)
        eye = torch.diag_embed(torch.ones(*batch_shape, 5))
        return MultivariateNormal(torch.randn(*batch_shape, 5), mat @ mat.transpose(-1, -2) + eye) 
Example 13
def matrix(self):
        """Matrix form of the butterfly matrix
        """
        if not self.complex:
            return (torch.diag(self.diag)
                    + torch.diag(self.subdiag, -self.diagonal)
                    + torch.diag(self.superdiag, self.diagonal))
        else: # Use torch.diag_embed (available in Pytorch 1.0) to deal with complex case.
            return (torch.diag_embed(self.diag.t(), dim1=0, dim2=1)
                    + torch.diag_embed(self.subdiag.t(), -self.diagonal, dim1=0, dim2=1)
                    + torch.diag_embed(self.superdiag.t(), self.diagonal, dim1=0, dim2=1)) 
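
The butterfly-matrix example above also uses the offset, dim1, and dim2 arguments of torch.diag_embed, which place the values on a sub- or super-diagonal and choose which output dimensions form the matrix. A minimal sketch of those arguments (independent of the project above):

import torch

v = torch.arange(1., 4.)                              # shape (3,)
sub = torch.diag_embed(v, offset=-1)                  # shape (4, 4), v on the first subdiagonal
sup = torch.diag_embed(v, offset=1, dim1=0, dim2=1)   # shape (4, 4), v on the first superdiagonal
print(sub)
print(sup)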
Example 14
def _get_test_posterior(shape, device, dtype, interleaved=True, lazy=False):
    mean = torch.rand(shape, device=device, dtype=dtype)
    n_covar = shape[-2:].numel()
    diag = torch.rand(shape, device=device, dtype=dtype)
    diag = diag.view(*diag.shape[:-2], n_covar)
    a = torch.rand(*shape[:-2], n_covar, n_covar, device=device, dtype=dtype)
    covar = a @ a.transpose(-1, -2) + torch.diag_embed(diag)
    if lazy:
        covar = NonLazyTensor(covar)
    if shape[-1] == 1:
        mvn = MultivariateNormal(mean.squeeze(-1), covar)
    else:
        mvn = MultitaskMultivariateNormal(mean, covar, interleaved=interleaved)
    return GPyTorchPosterior(mvn) 
Example 15
def test_lognorm_to_norm(self):
        for dtype in (torch.float, torch.double):

            # independent case
            mu = torch.tensor([0.25, 0.5, 1.0], device=self.device, dtype=dtype)
            diag = torch.tensor([0.5, 2.0, 1.0], device=self.device, dtype=dtype)
            Cov = torch.diag_embed((math.exp(1) - 1) * diag)
            mu_n, Cov_n = lognorm_to_norm(mu, Cov)
            mu_n_expected = torch.tensor(
                [-2.73179, -2.03864, -0.5], device=self.device, dtype=dtype
            )
            diag_expected = torch.tensor(
                [2.69099, 2.69099, 1.0], device=self.device, dtype=dtype
            )
            self.assertTrue(torch.allclose(mu_n, mu_n_expected))
            self.assertTrue(torch.allclose(Cov_n, torch.diag_embed(diag_expected)))

            # correlated case
            Z = torch.zeros(3, 3, device=self.device, dtype=dtype)
            Z[0, 2] = math.sqrt(math.exp(1)) - 1
            Z[2, 0] = math.sqrt(math.exp(1)) - 1
            mu = torch.ones(3, device=self.device, dtype=dtype)
            Cov = torch.diag_embed(mu * (math.exp(1) - 1)) + Z
            mu_n, Cov_n = lognorm_to_norm(mu, Cov)
            mu_n_expected = -0.5 * torch.ones(3, device=self.device, dtype=dtype)
            Cov_n_expected = torch.tensor(
                [[1.0, 0.0, 0.5], [0.0, 1.0, 0.0], [0.5, 0.0, 1.0]],
                device=self.device,
                dtype=dtype,
            )
            self.assertTrue(torch.allclose(mu_n, mu_n_expected, atol=1e-4))
            self.assertTrue(torch.allclose(Cov_n, Cov_n_expected, atol=1e-4)) 
Example 16
def test_norm_to_lognorm(self):
        for dtype in (torch.float, torch.double):

            # Test joint, independent
            expmu = torch.tensor([1.0, 2.0, 3.0], device=self.device, dtype=dtype)
            expdiag = torch.tensor([1.5, 2.0, 3], device=self.device, dtype=dtype)
            mu = torch.log(expmu)
            diag = torch.log(expdiag)
            Cov = torch.diag_embed(diag)
            mu_ln, Cov_ln = norm_to_lognorm(mu, Cov)
            mu_ln_expected = expmu * torch.exp(0.5 * diag)
            diag_ln_expected = torch.tensor(
                [0.75, 8.0, 54.0], device=self.device, dtype=dtype
            )
            Cov_ln_expected = torch.diag_embed(diag_ln_expected)
            self.assertTrue(torch.allclose(Cov_ln, Cov_ln_expected))
            self.assertTrue(torch.allclose(mu_ln, mu_ln_expected))

            # Test joint, correlated
            Cov[0, 2] = 0.1
            Cov[2, 0] = 0.1
            mu_ln, Cov_ln = norm_to_lognorm(mu, Cov)
            Cov_ln_expected[0, 2] = 0.669304
            Cov_ln_expected[2, 0] = 0.669304
            self.assertTrue(torch.allclose(Cov_ln, Cov_ln_expected))
            self.assertTrue(torch.allclose(mu_ln, mu_ln_expected))

            # Test marginal
            mu = torch.tensor([-1.0, 0.0, 1.0], device=self.device, dtype=dtype)
            v = torch.tensor([1.0, 2.0, 3.0], device=self.device, dtype=dtype)
            var = 2 * (torch.log(v) - mu)
            mu_ln = norm_to_lognorm_mean(mu, var)
            var_ln = norm_to_lognorm_variance(mu, var)
            mu_ln_expected = torch.tensor(
                [1.0, 2.0, 3.0], device=self.device, dtype=dtype
            )
            var_ln_expected = (torch.exp(var) - 1) * mu_ln_expected ** 2
            self.assertTrue(torch.allclose(mu_ln, mu_ln_expected))
            self.assertTrue(torch.allclose(var_ln, var_ln_expected)) 
Example 17
def test_round_trip(self):
        for dtype in (torch.float, torch.double):
            for batch_shape in ([], [2]):
                mu = 5 + torch.rand(*batch_shape, 4, device=self.device, dtype=dtype)
                a = 0.2 * torch.randn(
                    *batch_shape, 4, 4, device=self.device, dtype=dtype
                )
                diag = 3.0 + 2 * torch.rand(
                    *batch_shape, 4, device=self.device, dtype=dtype
                )
                Cov = a @ a.transpose(-1, -2) + torch.diag_embed(diag)
                mu_n, Cov_n = lognorm_to_norm(mu, Cov)
                mu_rt, Cov_rt = norm_to_lognorm(mu_n, Cov_n)
                self.assertTrue(torch.allclose(mu_rt, mu, atol=1e-4))
                self.assertTrue(torch.allclose(Cov_rt, Cov, atol=1e-4)) 
Example 18
def tobit_adjustment(mean: Tensor,
                     cov: Tensor,
                     lower: Optional[Tensor] = None,
                     upper: Optional[Tensor] = None,
                     probs: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
    assert cov.shape[-1] == cov.shape[-2]  # must be square in the last two dims

    if upper is None:
        upper = torch.full_like(mean, float('inf'))
    if lower is None:
        lower = torch.full_like(mean, -float('inf'))

    assert lower.shape == upper.shape == mean.shape

    is_cens_up = torch.isfinite(upper)
    is_cens_lo = torch.isfinite(lower)

    if not is_cens_up.any() and not is_cens_lo.any():
        return mean, cov

    F1, F2 = _F1F2(mean, cov, lower, upper)

    std = torch.diagonal(cov, dim1=-2, dim2=-1).sqrt()
    sqrt_pi = pi ** .5

    # prob censoring:
    if probs is None:
        prob_lo, prob_up = tobit_probs(mean=mean,
                                       cov=cov,
                                       lower=lower,
                                       upper=upper)
    else:
        prob_lo, prob_up = probs

    # adjust mean:
    lower_adj = torch.zeros_like(mean)
    lower_adj[is_cens_lo] = prob_lo[is_cens_lo] * lower[is_cens_lo]
    upper_adj = torch.zeros_like(mean)
    upper_adj[is_cens_up] = prob_up[is_cens_up] * upper[is_cens_up]
    mean_if_uncens = mean + (sqrt(2. / pi) * F1) * std
    mean_uncens_adj = (1. - prob_up - prob_lo) * mean_if_uncens
    mean_adj = mean_uncens_adj + upper_adj + lower_adj

    # adjust cov:
    diag_adj = torch.zeros_like(mean)
    for m in range(mean.shape[-1]):
        diag_adj[..., m] = (1. + 2. / sqrt_pi * F2[..., m] - 2. / pi * (F1[..., m] ** 2)) * cov[..., m, m]

    cov_adj = torch.diag_embed(diag_adj)

    return mean_adj, cov_adj 
Example 19
def _update_group(self,
                      obs: Tensor,
                      group_idx: Union[slice, Sequence[int]],
                      which_valid: Union[slice, Sequence[int]],
                      lower: Optional[Tensor] = None,
                      upper: Optional[Tensor] = None
                      ) -> Tuple[Tensor, Tensor]:
        # indices:
        idx_2d = bmat_idx(group_idx, which_valid)
        idx_3d = bmat_idx(group_idx, which_valid, which_valid)

        # observed values, censoring limits
        obs = obs[idx_2d]
        if lower is None:
            lower = torch.full_like(obs, -float('inf'))
        else:
            lower = lower[idx_2d]
            if torch.isnan(lower).any():
                raise ValueError("NaNs not allowed in `lower`")
        if upper is None:
            upper = torch.full_like(obs, float('inf'))
        else:
            upper = upper[idx_2d]
            if torch.isnan(upper).any():
                raise ValueError("NaNs not allowed in `upper`")

        if (lower == upper).any():
            raise RuntimeError("lower cannot == upper")

        # subset belief / design-mats:
        means = self.means[group_idx]
        covs = self.covs[group_idx]
        R = self.R[idx_3d]
        H = self.H[idx_2d]
        measured_means = H.matmul(means.unsqueeze(-1)).squeeze(-1)

        # calculate censoring fx:
        prob_lo, prob_up = tobit_probs(mean=measured_means,
                                       cov=R,
                                       lower=lower,
                                       upper=upper)
        prob_obs = torch.diag_embed(1 - prob_up - prob_lo)

        mm_adj, R_adj = tobit_adjustment(mean=measured_means,
                                         cov=R,
                                         lower=lower,
                                         upper=upper,
                                         probs=(prob_lo, prob_up))

        # kalman gain:
        K = self.kalman_gain(covariance=covs, H=H, R_adjusted=R_adj, prob_obs=prob_obs)

        # update
        means_new = self.mean_update(mean=means, K=K, residuals=obs - mm_adj)
        covs_new = self.covariance_update(covariance=covs, K=K, H=H, prob_obs=prob_obs)
        return means_new, covs_new 
Example 20
def test_transformed_posterior(self):
        for dtype in (torch.float, torch.double):
            for m in (1, 2):
                shape = torch.Size([3, m])
                mean = torch.rand(shape, dtype=dtype, device=self.device)
                variance = 1 + torch.rand(shape, dtype=dtype, device=self.device)
                if m == 1:
                    covar = torch.diag_embed(variance.squeeze(-1))
                    mvn = MultivariateNormal(mean.squeeze(-1), lazify(covar))
                else:
                    covar = torch.diag_embed(variance.view(*variance.shape[:-2], -1))
                    mvn = MultitaskMultivariateNormal(mean, lazify(covar))
                p_base = GPyTorchPosterior(mvn=mvn)
                p_tf = TransformedPosterior(  # dummy transforms
                    posterior=p_base,
                    sample_transform=lambda s: s + 2,
                    mean_transform=lambda m, v: 2 * m + v,
                    variance_transform=lambda m, v: m + 2 * v,
                )
                # mean, variance
                self.assertEqual(p_tf.device.type, self.device.type)
                self.assertTrue(p_tf.dtype == dtype)
                self.assertEqual(p_tf.event_shape, shape)
                self.assertTrue(torch.equal(p_tf.mean, 2 * mean + variance))
                self.assertTrue(torch.equal(p_tf.variance, mean + 2 * variance))
                # rsample
                samples = p_tf.rsample()
                self.assertEqual(samples.shape, torch.Size([1]) + shape)
                samples = p_tf.rsample(sample_shape=torch.Size([4]))
                self.assertEqual(samples.shape, torch.Size([4]) + shape)
                samples2 = p_tf.rsample(sample_shape=torch.Size([4, 2]))
                self.assertEqual(samples2.shape, torch.Size([4, 2]) + shape)
                # rsample w/ base samples
                base_samples = torch.randn(4, *shape, device=self.device, dtype=dtype)
                # incompatible shapes
                with self.assertRaises(RuntimeError):
                    p_tf.rsample(
                        sample_shape=torch.Size([3]), base_samples=base_samples
                    )
                # make sure sample transform is applied correctly
                samples_base = p_base.rsample(
                    sample_shape=torch.Size([4]), base_samples=base_samples
                )
                samples_tf = p_tf.rsample(
                    sample_shape=torch.Size([4]), base_samples=base_samples
                )
                self.assertTrue(torch.equal(samples_tf, samples_base + 2))
                # check error handling
                p_tf_2 = TransformedPosterior(
                    posterior=p_base, sample_transform=lambda s: s + 2
                )
                with self.assertRaises(NotImplementedError):
                    p_tf_2.mean
                with self.assertRaises(NotImplementedError):
                    p_tf_2.variance 
Example 21
def test_GPyTorchPosterior_Multitask(self):
        for dtype in (torch.float, torch.double):
            mean = torch.rand(3, 2, dtype=dtype, device=self.device)
            variance = 1 + torch.rand(3, 2, dtype=dtype, device=self.device)
            covar = variance.view(-1).diag()
            mvn = MultitaskMultivariateNormal(mean, lazify(covar))
            posterior = GPyTorchPosterior(mvn=mvn)
            # basics
            self.assertEqual(posterior.device.type, self.device.type)
            self.assertTrue(posterior.dtype == dtype)
            self.assertEqual(posterior.event_shape, torch.Size([3, 2]))
            self.assertTrue(torch.equal(posterior.mean, mean))
            self.assertTrue(torch.equal(posterior.variance, variance))
            # rsample
            samples = posterior.rsample(sample_shape=torch.Size([4]))
            self.assertEqual(samples.shape, torch.Size([4, 3, 2]))
            samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
            self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 2]))
            # rsample w/ base samples
            base_samples = torch.randn(4, 3, 2, device=self.device, dtype=dtype)
            samples_b1 = posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            samples_b2 = posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            self.assertTrue(torch.allclose(samples_b1, samples_b2))
            base_samples2 = torch.randn(4, 2, 3, 2, device=self.device, dtype=dtype)
            samples2_b1 = posterior.rsample(
                sample_shape=torch.Size([4, 2]), base_samples=base_samples2
            )
            samples2_b2 = posterior.rsample(
                sample_shape=torch.Size([4, 2]), base_samples=base_samples2
            )
            self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
            # collapse_batch_dims
            b_mean = torch.rand(2, 3, 2, dtype=dtype, device=self.device)
            b_variance = 1 + torch.rand(2, 3, 2, dtype=dtype, device=self.device)
            b_covar = torch.diag_embed(b_variance.view(2, 6))
            b_mvn = MultitaskMultivariateNormal(b_mean, lazify(b_covar))
            b_posterior = GPyTorchPosterior(mvn=b_mvn)
            b_base_samples = torch.randn(4, 1, 3, 2, device=self.device, dtype=dtype)
            b_samples = b_posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=b_base_samples
            )
            self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2])) 
Example 22
def _get_test_posterior(
    batch_shape: torch.Size,
    q: int = 1,
    m: int = 1,
    interleaved: bool = True,
    lazy: bool = False,
    independent: bool = False,
    **tkwargs
) -> GPyTorchPosterior:
    r"""Generate a Posterior for testing purposes.

    Args:
        batch_shape: The batch shape of the data.
        q: The number of candidates
        m: The number of outputs.
        interleaved: A boolean indicating the format of the
            MultitaskMultivariateNormal
        lazy: A boolean indicating if the posterior should be lazy
        independent: A boolean indicating whether the outputs are independent
        tkwargs: `device` and `dtype` tensor constructor kwargs.


    """
    if independent:
        mvns = []
        for _ in range(m):
            mean = torch.rand(*batch_shape, q, **tkwargs)
            a = torch.rand(*batch_shape, q, q, **tkwargs)
            covar = a @ a.transpose(-1, -2)
            flat_diag = torch.rand(*batch_shape, q, **tkwargs)
            covar = covar + torch.diag_embed(flat_diag)
            mvns.append(MultivariateNormal(mean, covar))
        mtmvn = MultitaskMultivariateNormal.from_independent_mvns(mvns)
    else:
        mean = torch.rand(*batch_shape, q, m, **tkwargs)
        a = torch.rand(*batch_shape, q * m, q * m, **tkwargs)
        covar = a @ a.transpose(-1, -2)
        flat_diag = torch.rand(*batch_shape, q * m, **tkwargs)
        if lazy:
            covar = AddedDiagLazyTensor(covar, DiagLazyTensor(flat_diag))
        else:
            covar = covar + torch.diag_embed(flat_diag)
        mtmvn = MultitaskMultivariateNormal(mean, covar, interleaved=interleaved)
    return GPyTorchPosterior(mtmvn)