
def deproject_depth_to_points(depth, grid, intrinsics_inv):
    """Back-project a depth map to per-pixel 3D points: K_inv @ (u, v, 1) * depth.

    Args:
        depth: tensor whose size unpacks as [B, C, H, W]. It is multiplied
            element-wise with the [B, 3, H, W] ray tensor, so C must
            broadcast against 3 (presumably C == 1 -- confirm with callers).
        grid: homogeneous pixel coordinates (u, v, 1); must be expandable
            to [B, 3, H, W] (e.g. [1, 3, H, W]).
        intrinsics_inv: inverse camera intrinsics; must be matmul-compatible
            with a [B, 3, H*W] tensor (e.g. [B, 3, 3]).

    Returns:
        Tensor of shape [B, 3, H, W].
    """
    b, _, h, w = depth.size()
    # The matrix product treats the last two dimensions as matrices, so the
    # spatial dimensions are flattened: every column of the [B, 3, H*W] tensor
    # is one (u, v, 1) vector that gets multiplied by the intrinsics matrix.
    current_pixel_coords = (
        grid                   # e.g. [1, 3, H, W]
        .expand(b, 3, h, w)    # [B, 3, H, W]
        .reshape(b, 3, -1)     # [B, 3, H*W] := [B, 3, UV1]
    )
    rays = (
        (intrinsics_inv @ current_pixel_coords)  # [B, 3, 3] @ [B, 3, UV1]
        .reshape(b, 3, h, w)                     # [B, 3, H, W]
    )
    # Scale every ray by its depth; depth broadcasts over the 3 channels.
    p3d = rays * depth  # [B, 3, H, W]
    return p3d
def RGBalbedoSHToLight(colorImg, albedoImg, SH, confidence_map):
    """Estimate 9 spherical-harmonics lighting coefficients by least squares.

    The shading at every valid pixel is ``colorImg / albedoImg``; the light
    vector is the least-squares solution of ``SH_basis @ light = shading``
    over those pixels.

    Args:
        colorImg: 2-D image tensor, indexed with the same pixel indices as
            ``confidence_map``.
        albedoImg: 2-D albedo tensor, same indexing as ``colorImg``.
        SH: tensor whose first dimension holds (at least) 9 basis channels,
            each channel indexed like ``colorImg``.
        confidence_map: 2-D tensor; non-zero entries mark usable pixels.
            NOTE: it is modified in place -- entries where ``colorImg`` or
            ``albedoImg`` is exactly 0 are set to 0.

    Returns:
        Tuple ``(light_9, SH)``: ``light_9`` is the [9, 1] solution and ``SH``
        is the input basis, returned unchanged.
    """
    # Pixels with zero colour or zero albedo give a zero or undefined
    # shading ratio, so they are removed from the valid set.
    confidence_map[colorImg == 0] = 0
    confidence_map[albedoImg == 0] = 0
    # nonzero() yields [M, 2] indices; unbind along dim 1 gives (rows, cols),
    # usable as an index tuple. This only works for a 2-D confidence map.
    valid_idx = torch.unbind(confidence_map.nonzero(), 1)

    # Shading image at the valid pixels, as a column vector [M, 1].
    shading = torch.div(colorImg[valid_idx], albedoImg[valid_idx]).view(-1, 1)

    # One column per SH channel -> design matrix [M, 9].
    sh_basis = torch.stack([SH[k, :, :][valid_idx] for k in range(9)], dim=-1)

    # torch.gels(B, A) was removed from PyTorch; torch.linalg.lstsq(A, B)
    # returns the [9, 1] solution directly, so no slicing of the first
    # 9 rows is needed any more.
    light_9 = torch.linalg.lstsq(sh_basis, shading).solution

    return (light_9, SH)
def contrastive_divergence(self, input_data):
    """Run one CD-k update on a batch and return the reconstruction error.

    Updates ``self.weights``, ``self.visible_bias``, ``self.hidden_bias`` and
    their three momentum buffers in place.

    Args:
        input_data: batch of visible vectors, one row per sample (the
            original notes mention 64 x 784 as an example). Values are
            passed straight to ``self.sample_hidden``.

    Returns:
        Sum of squared differences between ``input_data`` and the visible
        probabilities produced by the last Gibbs step.
    """
    n_samples = input_data.size(0)

    # --- Positive phase: visible units clamped to the data -----------------
    data_hidden_probs = self.sample_hidden(input_data)
    # Threshold the probabilities against random draws to get binary states.
    hidden_states = (
        data_hidden_probs >= self._random_probabilities(self.num_hidden)
    ).float()
    # Visible/hidden correlation while the visible units hold training data.
    data_correlations = torch.matmul(input_data.t(), hidden_states)

    # --- Negative phase: k Gibbs steps starting from the sampled states ----
    for _ in range(self.k):
        model_visible_probs = self.sample_visible(hidden_states)
        model_hidden_probs = self.sample_hidden(model_visible_probs)
        hidden_states = (
            model_hidden_probs >= self._random_probabilities(self.num_hidden)
        ).float()

    # Visible/hidden correlation when the visible units are free-running.
    model_correlations = torch.matmul(model_visible_probs.t(), model_hidden_probs)

    # Data minus reconstruction; feeds both the visible-bias gradient and
    # the reported error.
    residual = input_data - model_visible_probs

    # --- Momentum buffers: decay, then accumulate the new gradient ---------
    self.weights_momentum *= self.momentum_coefficient
    self.weights_momentum += data_correlations - model_correlations

    self.visible_bias_momentum *= self.momentum_coefficient
    self.visible_bias_momentum += torch.sum(residual, dim=0)

    self.hidden_bias_momentum *= self.momentum_coefficient
    self.hidden_bias_momentum += torch.sum(data_hidden_probs - model_hidden_probs, dim=0)

    # --- Parameter step, averaged over the batch ---------------------------
    self.weights += self.weights_momentum * self.learning_rate / n_samples
    self.visible_bias += self.visible_bias_momentum * self.learning_rate / n_samples
    self.hidden_bias += self.hidden_bias_momentum * self.learning_rate / n_samples

    # L2 weight decay, applied after the gradient step.
    self.weights -= self.weights * self.weight_decay

    # Squared reconstruction error summed over the whole batch.
    return torch.sum(residual ** 2)
# ===========================================================

########## DEFINITIONS DONE ##########

########## LOAD DATASET ########## 
def __init__(self, args, env):
    """Build the online/target networks and the optimiser for the agent.

    Q(s,a) is the expected reward. Z is the full distribution from which Q is generated.
    Support represents the support of Z distribution (non-zero part of pdf).
    Z is represented with a fixed number of "atoms", which are pairs of values (x_i, p_i)
    composed by the discrete positions (x_i) equidistant along its support defined between
    Vmin-Vmax and the probability mass or "weight" (p_i) for that particular position.

    As an example, for a given (s,a) pair, we can represent Z(s,a) with 8 atoms as follows:

               .        .     .
            .  |     .  |  .  |
            |  |  .  |  |  |  |  .
            |  |  |  |  |  |  |  |
       Vmin ----------------------- Vmax

    Args:
        args: configuration object; this method reads ``num_atoms``, ``V_min``,
            ``V_max``, ``device``, ``batch_size``, ``multi_step``,
            ``model_path`` and ``adam_eps`` from it.
        env: environment object; only ``env.action_space`` is read.
    """
    self.action_space = env.action_space
    self.num_atoms = args.num_atoms
    self.Vmin = args.V_min
    self.Vmax = args.V_max
    # Atom positions x_i, equally spaced over [Vmin, Vmax]. The original line
    # chain-assigned this tensor to both self.Vmax and args.V_max, which
    # overwrote the scalar bound used by delta_z below.
    # NOTE(review): the attribute name `support` is inferred from the
    # docstring -- confirm against the methods that read it.
    self.support = torch.linspace(args.V_min, args.V_max, self.num_atoms).to(device=args.device)
    # Spacing between two neighbouring atoms.
    self.delta_z = (args.V_max - args.V_min) / (self.num_atoms - 1)
    self.batch_size = args.batch_size
    self.multi_step = args.multi_step

    self.online_net = RainbowDQN(args, self.action_space).to(device=args.device)
    if args.model_path and os.path.isfile(args.model_path):
        # When you call torch.load() on a file which contains GPU tensors,
        # those tensors will be loaded to GPU by default. Calling
        # torch.load(..., map_location='cpu') and then load_state_dict()
        # avoids a GPU RAM surge when loading a model checkpoint.
        self.online_net.load_state_dict(torch.load(args.model_path, map_location='cpu'))

    # The target network is never trained directly, so its parameters are
    # excluded from gradient computation.
    self.target_net = RainbowDQN(args, self.action_space).to(device=args.device)
    for param in self.target_net.parameters():
        param.requires_grad = False

    # NOTE(review): the original call read `parameters(),, eps=...` -- an
    # argument (most likely the learning rate) appears to have been lost.
    # Until that is confirmed, Adam's default learning rate is used.
    self.optimiser = optim.Adam(self.online_net.parameters(), eps=args.adam_eps)