Python source code examples: torch.gesv()
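
torch.gesv(B, A) solves the linear system A @ X = B and returns the tuple (solution, LU), where LU holds the LU factorization of A. It was deprecated around PyTorch 1.1 in favor of torch.solve(B, A) and has since been removed; current releases use torch.linalg.solve(A, B), which takes its arguments in the opposite order and returns only the solution. A minimal sketch of the call and its modern equivalent, assuming a recent PyTorch build:

import torch

A = torch.randn(3, 3)
B = torch.randn(3, 2)

# Legacy call used throughout the examples below (old PyTorch):
#   X, LU = torch.gesv(B, A)    # solves A @ X = B, also returns the LU factors of A
# Modern equivalent (argument order reversed, only the solution is returned):
X = torch.linalg.solve(A, B)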

Example 1
def inverse_no_cache(self, inputs):
        """Cost:
            output = O(D^3 + D^2N)
            logabsdet = O(D^3)
        where:
            D = num of features
            N = num of inputs
        """
        batch_size = inputs.shape[0]
        outputs = inputs - self.bias
        outputs, lu = torch.gesv(outputs.t(), self._weight)  # Linear-system solver.
        outputs = outputs.t()
        # The linear-system solver returns the LU decomposition of the weights, which we
        # can use to obtain the log absolute determinant directly.
        logabsdet = -torch.sum(torch.log(torch.abs(torch.diag(lu))))
        logabsdet = logabsdet * torch.ones(batch_size)
        return outputs, logabsdet 
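
The trick in Example 1 relies on the fact that the diagonal of the packed LU factor returned by the solver is the diagonal of U, whose product equals det(W) up to sign, so sum(log|diag(lu)|) gives log|det(W)|; the leading minus sign appears because the function evaluates the inverse transformation. A minimal sanity check of that identity against torch.slogdet, on a PyTorch version that still provides torch.gesv:

import torch

W = torch.randn(5, 5)
B = torch.randn(5, 3)

# gesv returns the packed LU factors of W alongside the solution.
X, lu = torch.gesv(B, W)
logabsdet_from_lu = torch.sum(torch.log(torch.abs(torch.diag(lu))))

# torch.slogdet returns (sign, log|det|).
sign, logabsdet_ref = torch.slogdet(W)
assert torch.allclose(logabsdet_from_lu, logabsdet_ref, atol=1e-4)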
Example 2
def update_cov(self, J_cor, u_odo):
		H, J = self.compute_jac_update(u_odo)
		# add the corrected Jacobian to the measurement noise covariance

		R = torch.zeros(J.shape[1], J.shape[1])
		for i in range(J_cor.shape[0]):
			J_i = J[i]
			J_i[:, :6] += J_cor[i]
			R += J_i.mm(self.R).mm(J_i.t())

		S = H.mm(self.P).mm(H.t()) + R
		K_prefix = self.P.mm(H.t())
		Id = torch.eye(self.P.shape[-1])
		ImKH = Id - K_prefix.mm(torch.gesv(H, S)[0])
		# *Joseph form* of covariance update for numerical stability.
		self.P = ImKH.mm(self.P).mm(ImKH.transpose(-1, -2)) \
			+ K_prefix.mm(torch.gesv((K_prefix.mm(torch.gesv(R, S)[0])).transpose(-1, -2), S)[0])
		return K_prefix, S 
Example 3
def state_and_cov_update(Rot, v, p, b_omega, b_acc, Rot_c_i, t_c_i, P, H, r, R):
        S = H.mm(P).mm(H.t()) + R
        Kt, _ = torch.gesv(P.mm(H.t()).t(), S)
        K = Kt.t()
        dx = K.mv(r.view(-1))

        dR, dxi = TORCHIEKF.sen3exp(dx[:9])
        dv = dxi[:, 0]
        dp = dxi[:, 1]
        Rot_up = dR.mm(Rot)
        v_up = dR.mv(v) + dv
        p_up = dR.mv(p) + dp

        b_omega_up = b_omega + dx[9:12]
        b_acc_up = b_acc + dx[12:15]

        dR = TORCHIEKF.so3exp(dx[15:18])
        Rot_c_i_up = dR.mm(Rot_c_i)
        t_c_i_up = t_c_i + dx[18:21]

        I_KH = TORCHIEKF.IdP - K.mm(H)
        P_upprev = I_KH.mm(P).mm(I_KH.t()) + K.mm(R).mm(K.t())
        P_up = (P_upprev + P_upprev.t())/2
        return Rot_up, v_up, p_up, b_omega_up, b_acc_up, Rot_c_i_up, t_c_i_up, P_up 
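
The line Kt, _ = torch.gesv(P.mm(H.t()).t(), S) in Example 3 computes the Kalman gain K = P.mm(H.t()).mm(S.inverse()) without forming the inverse explicitly: because S is symmetric, solving S @ K.T = (P @ H.T).T and transposing gives the same result. A small standalone sketch of the pattern with the current solver API (all names here are illustrative):

import torch

n, m = 6, 3
P = torch.randn(n, n)
P = P @ P.t() + torch.eye(n)          # a symmetric positive-definite covariance
H = torch.randn(m, n)
R = torch.eye(m)
S = H @ P @ H.t() + R

# Solve S @ K.T = (P @ H.T).T instead of inverting S.
K = torch.linalg.solve(S, (P @ H.t()).t()).t()
assert torch.allclose(K, P @ H.t() @ torch.inverse(S), atol=1e-4)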
Example 4
def weight_inverse_and_logabsdet(self):
        """
        Cost:
            inverse = O(D^3)
            logabsdet = O(D)
        where:
            D = num of features
        """
        # If both weight inverse and logabsdet are needed, it's cheaper to compute both together.
        identity = torch.eye(self.features, self.features)
        weight_inv, lu = torch.gesv(identity, self._weight)  # Linear-system solver.
        logabsdet = torch.sum(torch.log(torch.abs(torch.diag(lu))))
        return weight_inv, logabsdet 
Example 5
def pre_factor_kkt(Q, G, A):
    """ Perform all one-time factorizations and cache relevant matrix products"""
    nineq, nz, neq, _ = get_sizes(G, A)

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T           ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]

    U_Q = torch.potrf(Q)
    # partial cholesky of S matrix
    U_S = torch.zeros(neq + nineq, neq + nineq).type_as(Q)

    G_invQ_GT = torch.mm(G, torch.potrs(G.t(), U_Q))
    R = G_invQ_GT
    if neq > 0:
        invQ_AT = torch.potrs(A.t(), U_Q)
        A_invQ_AT = torch.mm(A, invQ_AT)
        G_invQ_AT = torch.mm(G, invQ_AT)

        # TODO: torch.potrf sometimes says the matrix is not PSD even
        # though numpy's Cholesky succeeds. I filed an issue at
        # https://github.com/pytorch/pytorch/issues/199
        try:
            U11 = torch.potrf(A_invQ_AT)
        except:
            U11 = torch.Tensor(np.linalg.cholesky(
                A_invQ_AT.cpu().numpy())).type_as(A_invQ_AT)

        # TODO: torch.trtrs is currently not implemented on the GPU
        # and we are using gesv as a workaround.
        U12 = torch.gesv(G_invQ_AT.t(), U11.t())[0]
        U_S[:neq, :neq] = U11
        U_S[:neq, neq:] = U12
        R -= torch.mm(U12.t(), U12)

    return U_Q, U_S, R 
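
Example 5 combines several legacy LAPACK-style routines (torch.potrf, torch.potrs, torch.trtrs, torch.gesv) that were all renamed in later PyTorch releases. A rough mapping, shown as a sketch rather than a drop-in replacement, since argument orders and defaults differ:

import torch

A = torch.randn(4, 4)
Q = A @ A.t() + 4 * torch.eye(4)      # a symmetric positive-definite matrix
B = torch.randn(4, 2)

L = torch.linalg.cholesky(Q)                             # was: torch.potrf(Q, upper=False)
X1 = torch.cholesky_solve(B, L)                          # was: torch.potrs(B, L, upper=False)
X2 = torch.linalg.solve_triangular(L, B, upper=False)    # was: torch.trtrs(B, L, upper=False)[0]
X3 = torch.linalg.solve(Q, B)                            # was: torch.gesv(B, Q)[0]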
Example 6
def binv(b_mat):
    """
    Computes an inverse of each matrix in the batch.
    Pytorch 0.4.1 does not support batched matrix inverse.
    Hence, we are solving AX=I.
    
    Parameters:
      b_mat:  a (n_batch, n, n) Tensor.
    Returns: a (n_batch, n, n) Tensor.
    """

    id_matrix = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat).cuda()
    b_inv, _ = torch.gesv(id_matrix, b_mat)
    
    return b_inv 
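
Example 6 (and the b_inv helpers in Examples 10 and 13) works around the missing batched matrix inverse in PyTorch 0.4.1 by solving AX = I. Later releases invert a batch directly, so a sketch of the same helper without the solver, assuming a PyTorch version that ships torch.linalg:

import torch

def binv_modern(b_mat):
    # torch.linalg.inv operates over the leading batch dimensions.
    return torch.linalg.inv(b_mat)

b_mat = torch.randn(8, 5, 5) + 5 * torch.eye(5)   # keep the batch well conditioned
assert torch.allclose(binv_modern(b_mat) @ b_mat,
                      torch.eye(5).expand_as(b_mat), atol=1e-4)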
Example 7
def update_state(self, K_prefix, S, dy):
		dx = K_prefix.mm(torch.gesv(dy, S)[0]).squeeze(1) # K*dy
		self.x[:3] += dx[:3]
		self.x[3:6] += (SO3.exp(dx[3:6]).dot(SO3.from_rpy(self.x[3:6]))).to_rpy()
		self.x[6:9] += dx[6:9] 
Example 8
def solve_interpolation(train_points, train_values, order, regularization_weight):
    device = train_points.device
    b, n, d = train_points.shape
    k = train_values.shape[-1]

    c = train_points
    f = train_values.float()

    matrix_a = phi(cross_squared_distance_matrix(c, c), order).unsqueeze(0)  # [b, n, n]

    # Append ones to the feature values for the bias term in the linear model.
    ones = torch.ones(1, dtype=train_points.dtype, device=device).view([-1, 1, 1])
    matrix_b = torch.cat((c, ones), 2).float()  # [b, n, d + 1]

    # [b, n + d + 1, n]
    left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1)

    num_b_cols = matrix_b.shape[2]  # d + 1

    # In Tensorflow, zeros are used here. Pytorch solve fails with zeros
    # for some reason we don't understand.
    # So instead we use very tiny randn values (variance of one, zero mean)
    # on one side of our multiplication.
    lhs_zeros = torch.randn((b, num_b_cols, num_b_cols), device=device) / 1e10
    right_block = torch.cat((matrix_b, lhs_zeros), 1)  # [b, n + d + 1, d + 1]
    lhs = torch.cat((left_block, right_block), 2)  # [b, n + d + 1, n + d + 1]

    rhs_zeros = torch.zeros(
        (b, d + 1, k), dtype=train_points.dtype, device=device
    ).float()
    rhs = torch.cat((f, rhs_zeros), 1)  # [b, n + d + 1, k]

    # Then, solve the linear system and unpack the results.
    X, LU = torch.gesv(rhs, lhs)
    w = X[:, :n, :]
    v = X[:, n:, :]

    return w, v 
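
Examples 8 and 12 rely on the batched form of torch.gesv: lhs has shape [b, n + d + 1, n + d + 1], rhs has shape [b, n + d + 1, k], and each system in the batch is solved independently. On current PyTorch the same batched solve would be written, as a sketch:

import torch

b, m, k = 2, 7, 3
lhs = torch.randn(b, m, m)
rhs = torch.randn(b, m, k)

# Batched solve; replaces: X, LU = torch.gesv(rhs, lhs)
X = torch.linalg.solve(lhs, rhs)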
Example 9
def matrix_solve(self, matrix, rhs, adjoint=None):
        # Solve matrix @ x = rhs; torch.gesv returns (solution, LU), keep only the solution.
        return torch.gesv(rhs, matrix)[0]
Example 10
def b_inv(b_mat):
    ''' Performs batch matrix inversion.

    Arguments:
        b_mat: the batch of matrices that should be inverted
    '''

    eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat)
    b_inv, _ = torch.gesv(eye, b_mat)
    return b_inv 
Example 11
def apply_affine2point(points, theta, shape):
  assert points.size(0) == 3, 'invalid points shape : {:}'.format(points.size())
  with torch.no_grad():
    ok_points = points[2,:] == 1
    assert torch.sum(ok_points).item() > 0, 'there is no visible point'
    points[:2,:] = normalize_points(shape, points[:2,:])

    norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()

    trans_points, ___ = torch.gesv(points[:, ok_points], theta)

    norm_trans_points[:, ok_points] = trans_points
    
  return norm_trans_points 
Example 12
def solve_interpolation(train_points, train_values, order, regularization_weight):
    device = train_points.device
    b, n, d = train_points.shape
    k = train_values.shape[-1]

    c = train_points
    f = train_values.float()

    matrix_a = phi(cross_squared_distance_matrix(c, c), order).unsqueeze(0)  # [b, n, n]

    # Append ones to the feature values for the bias term in the linear model.
    ones = torch.ones(1, dtype=train_points.dtype, device=device).view([-1, 1, 1])
    matrix_b = torch.cat((c, ones), 2).float()  # [b, n, d + 1]

    # [b, n + d + 1, n]
    left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1)

    num_b_cols = matrix_b.shape[2]  # d + 1

    # In Tensorflow, zeros are used here. Pytorch solve fails with zeros for some reason we don't understand.
    # So instead we use very tiny randn values (variance of one, zero mean) on one side of our multiplication.
    lhs_zeros = torch.randn((b, num_b_cols, num_b_cols), device=device) / 1e10
    right_block = torch.cat((matrix_b, lhs_zeros), 1)  # [b, n + d + 1, d + 1]
    lhs = torch.cat((left_block, right_block), 2)  # [b, n + d + 1, n + d + 1]

    rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype, device=device).float()
    rhs = torch.cat((f, rhs_zeros), 1)  # [b, n + d + 1, k]

    # Then, solve the linear system and unpack the results.
    X, LU = torch.gesv(rhs, lhs)
    w = X[:, :n, :]
    v = X[:, n:, :]

    return w, v 
Example 13
def b_inv(b_mat):
    ''' Performs batch matrix inversion.

    Arguments:
        b_mat: the batch of matrices that should be inverted
    '''

    eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat)
    b_inv, _ = torch.gesv(eye, b_mat)
    return b_inv 
Example 14
def batched_mat_inv(mat_tensor):
    """
    [TESTED, file: valid_banet_mat_inv.py]
    Returns the inverses of a batch of square matrices.
    alternative implementation: torch.gesv() with LU decomposition
    :param mat_tensor: batched square-matrix tensor with shape (N, m, m); N is the batch size, m is the size of each square matrix
    :return: the batch of inverse matrices, with the same shape as the input
    """
    n = mat_tensor.size(-1)
    flat_bmat_inv = torch.stack([m.inverse() for m in mat_tensor.view(-1, n, n)])
    return flat_bmat_inv.view(mat_tensor.shape) 
Example 15
def optimization_step(A, Y, Z, beta, apply_cg, mode='prox_cg1'):
    """
    Optimization step for several different modes:
        - 'prox_exact' takes an exact proximal step
        - 'prox_cgN' takes an approximate proximal step, generated by N conjugate gradient steps
        - 'gradient' performs a gradient step; this method recovers classical SGD

    Taking a proximal step amounts to solving:

        argmin_{X} 1/2 ||AX - Y||^2 + beta/2 ||X - Z||^2,

    i.e. solving a linear system (exactly or approximately).

    The method argument 'apply_cg' is the one taken from the respective ProxProp module.
    """
    apply_A = partial(apply_cg, A)

    if mode == 'prox_exact':
        AA = A.t().mm(A)
        I = torch.eye(A.size(1)).type_as(A)
        A_tilde = AA + beta * I
        b_tilde = A.t().mm(Y) + beta*Z
        X, _ = torch.gesv(b_tilde, A_tilde)
    elif mode[:7] == 'prox_cg':
        num_prox_steps = int(mode[7:])
        apply_A_tilde = lambda x : apply_A(x, 3) + beta*x
        b_tilde = apply_A(Y,2) + beta * Z
        res = conjugate_gradient_block(apply_A_tilde, b_tilde, x0=Z, maxit=num_prox_steps)
        X = res[0]
    elif mode == 'gradient':
        X = Z - (apply_A(apply_A(Z,1) - Y,2))
    else:
        raise ValueError('The optimization mode "{}" you have specified is not valid.'.format(mode))
    return X 
Example 16
def finalize(self, train=False):
        """
        Finalize training with LU factorization or Pseudo-inverse
        """
        if self.learning_algo == 'inv':
            if not self.averaged:
                ridge_xTx = self.xTx + self.ridge_param * torch.eye(self.input_dim + self.with_bias, dtype=self.dtype)
                inv_xTx = ridge_xTx.inverse()
                self.w_out.data = torch.mm(inv_xTx, self.xTy).data
            else:
                # Average
                self.xTx = self.xTx / self.n_samples
                self.xTy = self.xTy / self.n_samples

                # Algo
                ridge_xTx = self.xTx + self.ridge_param * torch.eye(self.input_dim + self.with_bias, dtype=self.dtype)
                inv_xTx = ridge_xTx.inverse()
                self.w_out.data = torch.mm(inv_xTx, self.xTy).data
            # end if
        else:
            self.w_out.data = torch.gesv(self.xTy, self.xTx + torch.eye(self.esn_cell.output_dim).mul(self.ridge_param)).data
        # end if

        # Not in training mode anymore
        self.train(train)
    # end finalize
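
Both branches of Example 16 solve the ridge-regression system (xTx + ridge * I) w_out = xTy, once through an explicit inverse and once through the linear solver; solving the system is generally the better-conditioned route. A small standalone sketch of that equivalence (the variable names are illustrative, not those of the original module):

import torch

X = torch.randn(100, 20)
Y = torch.randn(100, 4)
ridge = 0.1

xTx = X.t() @ X
xTy = X.t() @ Y
reg = xTx + ridge * torch.eye(20)

w_inv = reg.inverse() @ xTy               # explicit-inverse branch
w_solve = torch.linalg.solve(reg, xTy)    # solver branch (was torch.gesv(xTy, reg)[0])
assert torch.allclose(w_inv, w_solve, atol=1e-4)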
Example 17
def ir_logistic(self, X, w0, y_inner):
        # iteration 0
        eta = w0  # + zeros
        mu = torch.sigmoid(eta)
        s = mu * (1 - mu)
        z = eta + (y_inner - mu) / s
        S = torch.diag(s)
        # Woodbury with regularization
        w_ = mm(t(X, 0, 1), inv(mm(X, t(X, 0, 1)) + self.lambda_(inv(S))))
        z_ = t(z.unsqueeze(0), 0, 1)
        w = mm(w_, z_)
        # it 1...N
        for i in range(self.iterations - 1):
            eta = w0 + mm(X, w).squeeze(1)
            mu = torch.sigmoid(eta)
            s = mu * (1 - mu)
            z = eta + (y_inner - mu) / s
            S = torch.diag(s)
            z_ = t(z.unsqueeze(0), 0, 1)
            if not self.linsys:
                w_ = mm(t(X, 0, 1), inv(mm(X, t(X, 0, 1)) + self.lambda_(inv(S))))
                w = mm(w_, z_)
            else:
                A = mm(X, t(X, 0, 1)) + self.lambda_(inv(S))
                w_, _ = gesv(z_, A)
                w = mm(t(X, 0, 1), w_)

        return w 
Example 18
def rr_standard(self, x, n_way, n_shot, I, yrr_binary, linsys):
        x /= np.sqrt(n_way * n_shot * self.n_augment)

        if not linsys:
            w = mm(mm(inv(mm(t(x, 0, 1), x) + self.lambda_rr(I)), t(x, 0, 1)), yrr_binary)
        else:
            A = mm(t_(x), x) + self.lambda_rr(I)
            v = mm(t_(x), yrr_binary)
            w, _ = gesv(v, A)

        return w 
Example 19
def rr_woodbury(self, x, n_way, n_shot, I, yrr_binary, linsys):
        x /= np.sqrt(n_way * n_shot * self.n_augment)

        if not linsys:
            w = mm(mm(t(x, 0, 1), inv(mm(x, t(x, 0, 1)) + self.lambda_rr(I))), yrr_binary)
        else:
            A = mm(x, t_(x)) + self.lambda_rr(I)
            v = yrr_binary
            w_, _ = gesv(v, A)
            w = mm(t_(x), w_)

        return w
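
Examples 18 and 19 are the two standard formulations of the same ridge problem: rr_standard solves a D x D system in feature space, while rr_woodbury applies the Woodbury identity and solves an N x N system in sample space, which is cheaper when there are far fewer (augmented) samples than features. A minimal sketch checking that the two forms agree (names illustrative):

import torch

N, D = 10, 50                      # few samples, many features
x = torch.randn(N, D)
y = torch.randn(N, 3)
lam = 0.5

# Standard form: (x.T @ x + lam * I_D) w = x.T @ y
w_std = torch.linalg.solve(x.t() @ x + lam * torch.eye(D), x.t() @ y)

# Woodbury form: w = x.T @ (x @ x.T + lam * I_N)^{-1} @ y
w_wdb = x.t() @ torch.linalg.solve(x @ x.t() + lam * torch.eye(N), y)

assert torch.allclose(w_std, w_wdb, atol=1e-4)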