
def idct_8x8(image):
  """Apply the inverse 8x8 DCT to a tensor of coefficient blocks.

  The coefficients are weighted by ``outer(alpha, alpha)`` where
  ``alpha = [1/sqrt(2), 1, ..., 1]``, contracted against the cosine basis
  ``cos((2u+1) x pi/16) * cos((2v+1) y pi/16)``, scaled by 0.25 and shifted
  by +128.

  Args:
    image: Coefficient tensor. It is multiplied with an 8x8 CUDA tensor, so
      presumably it lives on the GPU and is shaped (..., 8, 8) -- TODO
      confirm against the caller.

  Returns:
    ``0.25 * tensordot_pytorch(weighted, basis, dims=2) + 128``.
  """
  dc_weights = np.array([1. / np.sqrt(2)] + [1] * 7)
  weight_grid = torch.FloatTensor(np.outer(dc_weights, dc_weights)).cuda()
  weighted = image * weight_grid

  # One-dimensional table: cos_table[k, n] = cos((2n + 1) * k * pi / 16).
  cos_table = np.empty((8, 8))
  for k, n in itertools.product(range(8), repeat=2):
    cos_table[k, n] = np.cos((2 * n + 1) * k * np.pi / 16)

  # basis[x, y, u, v] = cos_table[x, u] * cos_table[y, v], stored as float32.
  basis = (cos_table[:, None, :, None] *
           cos_table[None, :, None, :]).astype(np.float32)
  basis_gpu = torch.as_tensor(basis, device="cuda")
  return 0.25 * tensordot_pytorch(weighted, basis_gpu, dims=2) + 128

# -3. Block joining 
def get_obs(Asymm, H, Sx, Sy, Sz, C, E):
    """Evaluate normalised expectation values from a two-site density matrix.

    The double-layer tensor ``Td`` is built from ``Asymm``, contracted with the
    corner ``C`` and edge ``E`` environment tensors into a half-environment
    ``EL``, and two copies of ``EL`` are joined into the density matrix
    ``Rho``. Each observable is ``trace(Rho @ Op) / trace(Rho)``.

    Index conventions (from the original comments):
    A(phy,u,l,d,r), C(d,r), E(u,r,d).

    Args:
        Asymm: 5-index site tensor; ``Asymm.size()`` supplies the physical
            dimension ``Da[0]`` and the four bond dimensions ``Da[1:]``.
        H, Sx, Sy, Sz: Operators as ``(Da[0]**2, Da[0]**2)`` matrices.
        C: Corner environment matrix.
        E: Edge environment tensor with three indices.

    Returns:
        Tuple ``(Energy, Mx, My, Mz)`` of 0-dim tensors.
    """
    Da = Asymm.size()
    # Double-layer tensor: fuse ket/bra bond indices pairwise and keep the
    # two physical indices (m, n) open at the end.
    Td = torch.einsum('mefgh,nabcd->eafbgchdmn', Asymm, Asymm).contiguous().view(
        Da[1]**2, Da[2]**2, Da[3]**2, Da[4]**2, Da[0], Da[0])

    CE = torch.tensordot(C, E, ([1], [0]))          # C(1d)E(dga)->CE(1ga)
    EL = torch.tensordot(E, CE, ([2], [0]))         # E(2e1)CE(1ga)->EL(2ega)  use E(2e1) == E(1e2)
    EL = torch.tensordot(EL, Td, ([1, 2], [1, 0]))  # EL(2ega)T(gehbmn)->EL(2ahbmn)
    EL = torch.tensordot(EL, CE, ([0, 2], [0, 1]))  # EL(2ahbmn)CE(2hc)->EL(abmnc), use CE(2hc) == CE(1ga)
    Rho = torch.tensordot(EL, EL, ([0, 1, 4], [0, 1, 4])).permute(0, 2, 1, 3).contiguous().view(
        Da[0]**2, Da[0]**2)
    # Symmetrise explicitly to remove round-off asymmetry.
    Rho = 0.5 * (Rho + Rho.t())
    Tnorm = Rho.trace()

    # NOTE(review): the left operand of these products was truncated in the
    # source (",H).trace()"); restored as torch.mm(Rho, <operator>).
    Energy = torch.mm(Rho, H).trace() / Tnorm
    Mx = torch.mm(Rho, Sx).trace() / Tnorm
    My = torch.mm(Rho, Sy).trace() / Tnorm
    Mz = torch.mm(Rho, Sz).trace() / Tnorm

    return Energy, Mx, My, Mz
def forward(self, inputs):
        """Attentional pairwise-interaction pooling over a list of embeddings.

        Every unordered pair of embeddings is multiplied element-wise, an
        attention score per pair is computed from ``attention_W``,
        ``attention_b`` and ``projection_h``, and the attention-weighted sum
        of the pair products is projected with ``projection_p``.

        Side effect: stores the softmax scores in
        ``self.normalized_att_score``.

        :param inputs: list of tensors that are concatenated along dim 1;
            presumably each is (batch, 1, embedding_size) -- TODO confirm.
        :return: result of contracting the pooled vector's last axis with
            ``self.projection_p``.
        """
        embeds_vec_list = inputs
        row = []
        col = []

        # NOTE(review): the loop body and the ``torch.cat(`` heads were
        # truncated in the source; restored as the usual "collect left/right
        # member of each pair, then concatenate" form.
        for r, c in itertools.combinations(embeds_vec_list, 2):
            row.append(r)
            col.append(c)

        p = torch.cat(row, dim=1)
        q = torch.cat(col, dim=1)
        bi_interaction = p * q

        attention_temp = F.relu(torch.tensordot(
            bi_interaction, self.attention_W, dims=([-1], [0])) + self.attention_b)

        # Softmax over dim 1, i.e. across the pairs.
        self.normalized_att_score = F.softmax(torch.tensordot(
            attention_temp, self.projection_h, dims=([-1], [0])), dim=1)
        attention_output = torch.sum(
            self.normalized_att_score * bi_interaction, dim=1)

        attention_output = self.dropout(attention_output)  # training

        afm_out = torch.tensordot(
            attention_output, self.projection_p, dims=([-1], [0]))
        return afm_out
def forward(self, inputs):
        """Multi-head self-attention over the field axis of ``inputs``.

        Queries, keys and values are linear projections of ``inputs``; the
        projections are split into heads of width ``att_embedding_size``,
        attended per head, and the heads are concatenated back on the last
        axis. An optional residual projection (``use_res``) is added before
        the final ReLU.

        Side effect: stores the attention weights in
        ``self.normalized_att_scores``.

        :param inputs: 3-D tensor (the shape comments suggest
            (batch, fields, embedding)).
        :raises ValueError: if ``inputs`` is not 3-dimensional.
        :return: 3-D tensor whose last axis has all heads concatenated.
        """
        if len(inputs.shape) != 3:
            raise ValueError(
                "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(inputs.shape)))

        querys = torch.tensordot(inputs, self.W_Query,
                                 dims=([-1], [0]))  # None F D*head_num
        keys = torch.tensordot(inputs, self.W_key, dims=([-1], [0]))
        values = torch.tensordot(inputs, self.W_Value, dims=([-1], [0]))

        # Split the last axis into heads and stack them in front:
        # head_num None F D
        querys = torch.stack(torch.split(
            querys, self.att_embedding_size, dim=2))
        keys = torch.stack(torch.split(keys, self.att_embedding_size, dim=2))
        values = torch.stack(torch.split(
            values, self.att_embedding_size, dim=2))
        inner_product = torch.einsum(
            'bnik,bnjk->bnij', querys, keys)  # head_num None F F

        self.normalized_att_scores = F.softmax(
            inner_product, dim=-1)  # head_num None F F
        result = torch.matmul(self.normalized_att_scores,
                              values)  # head_num None F D

        # NOTE(review): this line was truncated in the source
        # ("result =, 1, ), dim=-1)"); restored as "split the head axis into
        # singletons and concatenate them on the last axis".
        result = torch.cat(torch.split(result, 1), dim=-1)
        result = torch.squeeze(result, dim=0)  # None F D*head_num
        if self.use_res:
            result += torch.tensordot(inputs, self.W_Res, dims=([-1], [0]))
        result = F.relu(result)

        return result
def forward(self, inputs):
        """Apply ``self.layer_num`` cross layers to ``inputs``.

        Each layer contracts axis 1 of the running state with
        ``self.kernels[i]``, multiplies the original input by that result,
        and adds ``self.bias[i]`` plus the running state (residual).

        :param inputs: tensor that gains a trailing singleton axis for the
            computation; presumably (batch, features) -- TODO confirm.
        :return: running state with the trailing axis squeezed away again.
        """
        base = inputs.unsqueeze(2)
        state = base
        for layer_idx in range(self.layer_num):
            projected = torch.tensordot(
                state, self.kernels[layer_idx], dims=([1], [0]))
            crossed = torch.matmul(base, projected)
            state = crossed + self.bias[layer_idx] + state
        return torch.squeeze(state, dim=2)
def tensordot(x, y, dims):
    """
    Wrapper around :func:`torch.tensordot` or :func:`np.tensordot`
    to operate on real-valued Funsors.

    Note this operates only on the ``output`` tensor. To perform sum-product
    contractions on named dimensions, instead use ``+`` and a reduction
    (the class reference that followed was lost in this file; presumably
    funsor's ``Reduce`` -- confirm).

    Arguments should satisfy::

        len(x.shape) >= dims
        len(y.shape) >= dims
        dims == 0 or x.shape[-dims:] == y.shape[:dims]

    :param Funsor x: A left hand argument.
    :param Funsor y: A right hand argument.
    :param int dims: The number of dimension of overlap of output shape.
    :rtype: Funsor
    """
    assert dims >= 0
    assert len(x.shape) >= dims
    assert len(y.shape) >= dims
    assert dims == 0 or x.shape[-dims:] == y.shape[:dims]
    # Assign one letter per output dim of x, then let y reuse x's trailing
    # ``dims`` letters so those axes are contracted by the einsum.
    x_start, x_end = 0, len(x.output.shape)
    y_start = x_end - dims
    y_end = y_start + len(y.output.shape)
    symbols = 'abcdefghijklmnopqrstuvwxyz'
    # The format string has three slots (x subscripts, y subscripts, output);
    # the y subscripts were missing in the source, which made ``format``
    # raise IndexError.
    equation = '{},{}->{}'.format(
        symbols[x_start:x_end],
        symbols[y_start:y_end],
        symbols[x_start:y_start] + symbols[x_end:y_end])
    return Einsum(equation, (x, y))
def _numeric_tensordot(x, y, dim):
    """Reference tensordot on raw arrays for the active backend.

    Uses :func:`torch.tensordot` when ``get_backend()`` reports ``"torch"``
    and :func:`np.tensordot` otherwise.

    :param x: left operand (backend array/tensor).
    :param y: right operand (backend array/tensor).
    :param dim: passed as ``dims`` / ``axes`` to the backend function.
    :return: the backend's tensordot result.
    """
    if get_backend() == "torch":
        # Imported lazily so the NumPy backend works without torch installed.
        import torch

        return torch.tensordot(x, y, dim)
    # The source had lost the ``else:`` here, leaving this return unreachable
    # and making every non-torch backend return None.
    return np.tensordot(x, y, axes=dim)
def test_tensor_tensordot(x_shape, xy_shape, y_shape):
    """Check Funsor ``tensordot`` against the raw backend tensordot.

    ``xy_shape`` is the overlapping block: it forms the trailing dims of the
    left operand and the leading dims of the right operand.
    """
    overlap = len(xy_shape)
    lhs = randn(x_shape + xy_shape)
    rhs = randn(xy_shape + y_shape)

    actual = tensordot(Tensor(lhs), Tensor(rhs), overlap)
    expected = Tensor(_numeric_tensordot(lhs, rhs, overlap))

    assert_close(actual, expected, atol=1e-5, rtol=None)
def rgb_to_ycbcr(image):
  """Convert RGB values to YCbCr using the 16/128/128-offset coefficients.

  The last axis of ``image`` is contracted with the transposed 3x3
  coefficient matrix (divided by 255) and the per-channel shift
  ``[16, 128, 128]`` is added.

  Args:
    image: Tensor whose last axis is contracted with a 3x3 matrix; the shift
      is a CUDA tensor, so ``image`` is presumably on the GPU with a trailing
      RGB axis of size 3 -- TODO confirm.

  Returns:
    Result of ``tensordot_pytorch(image, matrix, dims=1) + shift``.
  """
  matrix = np.array(
      [[65.481, 128.553, 24.966],
       [-37.797, -74.203, 112.],
       [112., -93.786, -18.214]],
      dtype=np.float32).T / 255
  shift = torch.as_tensor([16., 128., 128.], device="cuda")

  # Hand the matrix over as a CUDA tensor, as every sibling colour-conversion
  # helper does; the source passed the raw NumPy array here.
  result = tensordot_pytorch(image, torch.as_tensor(matrix, device="cuda"), dims=1) + shift
  return result
def rgb_to_ycbcr_jpeg(image):
  """Convert RGB values to YCbCr using the 0/128/128-offset coefficients.

  The last axis of ``image`` is contracted with the transposed 3x3
  coefficient matrix and the per-channel shift ``[0, 128, 128]`` is added.

  Args:
    image: Tensor whose last axis is contracted with a 3x3 matrix; presumably
      on the GPU with a trailing RGB axis of size 3 -- TODO confirm.

  Returns:
    Result of ``tensordot_pytorch(image, matrix, dims=1) + shift``.
  """
  # NOTE(review): the closing "dtype=np.float32).T" of this call was missing
  # in the source; restored to match the sibling conversion helpers (float32,
  # transposed so that column j holds the coefficients of output channel j).
  matrix = np.array(
      [[0.299, 0.587, 0.114],
       [-0.168736, -0.331264, 0.5],
       [0.5, -0.418688, -0.081312]],
      dtype=np.float32).T
  shift = torch.as_tensor([0., 128., 128.], device="cuda")

  result = tensordot_pytorch(image, torch.as_tensor(matrix, device='cuda'), dims=1) + shift
  return result

# 2. Chroma subsampling 
def dct_8x8(image):
  """Apply the forward 8x8 DCT to a tensor of pixel blocks.

  The input is shifted by -128, contracted against the cosine basis
  ``cos((2x+1) u pi/16) * cos((2y+1) v pi/16)`` and scaled by
  ``0.25 * outer(alpha, alpha)`` with ``alpha = [1/sqrt(2), 1, ..., 1]``.

  Args:
    image: Tensor of pixel blocks; presumably on the GPU and shaped
      (..., 8, 8) -- TODO confirm against the caller.

  Returns:
    ``scale * tensordot_pytorch(image - 128, basis, dims=2)``.
  """
  centred = image - 128

  # One-dimensional table: cos_table[n, k] = cos((2n + 1) * k * pi / 16).
  cos_table = np.empty((8, 8))
  for n, k in itertools.product(range(8), repeat=2):
    cos_table[n, k] = np.cos((2 * n + 1) * k * np.pi / 16)

  # basis[x, y, u, v] = cos_table[x, u] * cos_table[y, v], stored as float32.
  basis = (cos_table[:, None, :, None] *
           cos_table[None, :, None, :]).astype(np.float32)

  dc_weights = np.array([1. / np.sqrt(2)] + [1] * 7)
  scale = torch.FloatTensor(np.outer(dc_weights, dc_weights) * 0.25).cuda()
  basis_gpu = torch.as_tensor(basis, device="cuda")
  return scale * tensordot_pytorch(centred, basis_gpu, dims=2)
def ycbcr_to_rgb(image):
  """Convert YCbCr values back to RGB.

  The last axis of ``image`` is contracted with the transposed 3x3
  coefficient matrix (divided by 256) and a per-channel offset is added.

  Args:
    image: Tensor whose last axis is contracted with a 3x3 matrix; presumably
      on the GPU with a trailing channel axis of size 3 -- TODO confirm.

  Returns:
    Result of ``tensordot_pytorch(image, matrix, dims=1) + shift``.
  """
  coefficients = [[298.082, 0, 408.583],
                  [298.082, -100.291, -208.120],
                  [298.082, 516.412, 0]]
  matrix = np.array(coefficients, dtype=np.float32).T / 256
  matrix_gpu = torch.tensor(matrix, device="cuda")
  shift = torch.as_tensor([-222.921, 135.576, -276.836], device="cuda")

  return tensordot_pytorch(image, matrix_gpu, dims=1) + shift
def test_tensordot():
  """Contract axes (1, 2) of a 2-filled and a 1-filled (2, 3, 4) tensor.

  Every output entry sums 3 * 4 products of 2 * 1, i.e. 24.
  """
  backend = pytorch_backend.PyTorchBackend()
  shape = (2, 3, 4)
  twos = backend.convert_to_tensor(2 * np.ones(shape))
  ones = backend.convert_to_tensor(np.ones(shape))

  contracted_axes = ((1, 2), (1, 2))
  actual = backend.tensordot(twos, ones, contracted_axes)

  expected = np.array([[24.0, 24.0], [24.0, 24.0]])
  np.testing.assert_allclose(expected, actual)
def test_eigsh_lanczos_0():
  """Smoke test: ``eigsh_lanczos`` with a callable operator must not crash."""
  backend = pytorch_backend.PyTorchBackend()
  dtype = torch.float64
  num_krylov_vecs = 4

  start_vector = backend.randn((2, 2, 2), dtype=dtype)
  raw = backend.randn((8, 8), dtype=dtype)
  # Adding the conjugate transpose makes the 8x8 matrix Hermitian; it is then
  # viewed as a rank-6 tensor acting on rank-3 vectors.
  hermitian = raw + backend.transpose(backend.conj(raw), (1, 0))
  hermitian = hermitian.reshape([2, 2, 2, 2, 2, 2])

  def matvec(x, mat):
    contracted = torch.tensordot(mat, x, ([0, 3, 5], [2, 0, 1]))
    return contracted.permute([2, 0, 1])

  backend.eigsh_lanczos(
      matvec, [hermitian], start_vector, num_krylov_vecs=num_krylov_vecs)
def forward(self, inputs, embed=True):
    """Embed indices, or project vectors back with the same (tied) weights.

    Args:
      inputs: index tensor when ``embed`` is true; otherwise a tensor whose
        last axis is contracted with ``self.w.t()``.
      embed: if true, look ``inputs`` up in ``self.w``; if false, compute
        ``inputs @ self.w.t() + self.b``.

    Returns:
      The embedding lookup, or the projected tensor plus bias.
    """
    if embed:
      return torch.nn.functional.embedding(inputs, self.w)
    # The source had lost the ``else:`` here, leaving this return unreachable
    # so that embed=False silently returned None.
    return torch.tensordot(inputs, self.w.t(), 1) + self.b
def sample(self, p, z=None):
        """Input p to be shaped [T,B,A,P] or [B,A,P], A: number of actions, P:
        number of atoms.  Optional input z is domain of atom-values, shaped
        [P].  Vector epsilon of length B will apply across Batch dimension.

        The distribution ``p`` is reduced to expected values by contracting
        its last axis with the atom values, and the result is passed to the
        parent class's ``sample``.
        """
        # ``z or self.z`` evaluates the tensor's truth value, which raises for
        # tensors with more than one element (and would discard an all-zero
        # single atom); test for None explicitly instead.
        atoms = self.z if z is None else z
        q = torch.tensordot(p, atoms, dims=1)
        return super().sample(q)
def reweight(self, msa1hot):
        """Compute one weight per row as the inverse of its neighbour count.

        The pairwise overlap between rows is the contraction of ``msa1hot``
        with itself over axes 1 and 2. Two rows are neighbours when their
        overlap exceeds ``msa1hot.size(1) * self.msa_cutoff``.

        :param msa1hot: 3-D tensor; the name suggests a one-hot encoded
            alignment of shape (sequences, positions, alphabet) -- TODO
            confirm.
        :return: 1-D tensor ``1 / neighbour_count``.
        """
        threshold = msa1hot.size(1) * self.msa_cutoff
        overlap = torch.tensordot(msa1hot, msa1hot, [[1, 2], [1, 2]])
        neighbour_count = (overlap > threshold).float().sum(-1)
        return 1.0 / neighbour_count
def reweight(self, msa1hot, eps=1e-9):
        """Batched variant: inverse neighbour count per row, per batch entry.

        For every element along dim 0 the pairwise overlap between rows is the
        contraction of that element with itself over its axes 1 and 2. Rows
        are neighbours when the overlap exceeds
        ``msa1hot.size(2) * self.msa_cutoff``.

        :param msa1hot: 4-D tensor; the name suggests a one-hot encoded
            alignment of shape (batch, sequences, positions, alphabet) --
            TODO confirm.
        :param eps: added to the neighbour count before inverting.
        :return: tensor ``1 / (neighbour_count + eps)`` with one row per
            batch entry.
        """
        threshold = msa1hot.size(2) * self.msa_cutoff

        per_batch = []
        for alignment in msa1hot:
            per_batch.append(
                torch.tensordot(alignment, alignment, [[1, 2], [1, 2]]))
        overlap = torch.stack(per_batch, 0)

        neighbour_count = (overlap > threshold).type_as(msa1hot).sum(-1)
        return 1.0 / (neighbour_count + eps)
def calculate_reward(
        self,
        slots: SlateSlots,
        rewards: Optional[SlateSlotValues] = None,
        slot_values: Optional[SlateSlotValues] = None,
        slot_weights: Optional[SlateSlotValues] = None,
    ) -> float:
        """Return the weighted sum of per-slot values as a Python float.

        Args:
            slots: slate slots, used only to derive default slot weights.
            rewards: per-slot rewards; required when ``slot_values`` is None.
            slot_values: per-slot values; derived via ``self.slot_values``
                from ``rewards`` when omitted.
            slot_weights: per-slot weights; derived via ``self.slot_weights``
                from ``slots`` when omitted.

        Returns:
            Dot product of the value and weight vectors.
        """
        # NOTE(review): the ``self`` parameter and the right-hand sides of
        # ``values`` / ``weights`` were missing in the source. They are
        # reconstructed from the body's own use of ``self`` and from the
        # ``.values`` / ``self._device`` usage visible in the neighbouring
        # estimator code -- confirm against the upstream file.
        if slot_values is None:
            assert rewards is not None
            slot_values = self.slot_values(rewards)
        values = slot_values.values.to(device=self._device)
        if slot_weights is None:
            slot_weights = self.slot_weights(slots)
        weights = slot_weights.values.to(device=self._device)
        return torch.tensordot(values, weights, dims=([0], [0])).item()
def _evaluate_sample(self, sample: LogSample) -> Optional[EstimatorSampleResult]:
        """Compute an importance-weighted result for one logged sample.

        Builds two vectors over the logged slate -- ``h`` from the target
        policy's slot expectations and ``p`` from the logging policy's --
        contracts each with the slot weights, and uses the ratio ``nu / de``
        (passed through ``self._weight_clamper``) as the sample weight.

        Returns None when either slot distribution is unavailable or when the
        denominator is zero while the numerator is not.

        NOTE(review): this block is damaged in the file -- one assignment has
        lost its right-hand side, an ``else:`` appears to be missing, and the
        final ``return`` call is cut off. See the inline notes.
        """
        log_slot_expects = sample.log_slot_item_expectations(sample.context.slots)
        if log_slot_expects is None:
            logger.warning("  Log slot distribution not available")
            return None
        tgt_slot_expects = sample.tgt_slot_expectations(sample.context.slots)
        if tgt_slot_expects is None:
            logger.warning("  Target slot distribution not available")
            return None
        slate_size = len(sample.context.slots)
        slot_weights = sample.slot_weights
        if slot_weights is None:
            # Default to uniform weights of 1 for every slot.
            slot_weights = SlateSlotValues(torch.ones(slate_size, dtype=torch.double))
        # NOTE(review): right-hand side missing; presumably the tensor held by
        # ``slot_weights`` moved to ``self._device`` -- confirm upstream.
        weights =
        if sample.slot_probabilities is not None:
            weights *= sample.slot_probabilities.values
        h = torch.zeros(slate_size, dtype=torch.double, device=self._device)
        p = torch.zeros(slate_size, dtype=torch.double, device=self._device)
        i = 0
        # Fill position i with the target / logging expectation of the item
        # that was actually shown in that slot.
        for slot, item in sample.log_slate:
            h[i] = tgt_slot_expects[slot][item]
            p[i] = log_slot_expects[slot][item]
            i += 1
        nu = torch.tensordot(h, weights, dims=([0], [0]))
        de = torch.tensordot(p, weights, dims=([0], [0]))
        if nu == de:
            weight = 1.0
        elif nu == 0:
            weight = 0.0
        elif de == 0:
            return None
            # NOTE(review): an ``else:`` was presumably lost before the next
            # line; as written it is unreachable after the ``return None``.
            weight = self._weight_clamper(nu / de)
        # NOTE(review): the remaining arguments and the closing parenthesis of
        # this call are missing from the file.
        return EstimatorSampleResult(
            sample.log_reward * weight,

    # pyre-fixme[14]: `evaluate` overrides method defined in `Estimator` inconsistently. 
def soft_embedding_lookup(embedding, soft_ids):
    r"""Transforms soft ids (e.g., probability distribution over ids) into
    embeddings, by mixing the embedding vectors with the soft weights.

    Args:
        embedding: A Tensor of shape ``[num_classes] + embedding-dim``
            containing the embedding vectors. Embedding can have dimensionality
            > 1, i.e., :attr:`embedding` can be of shape
            ``[num_classes, emb_dim_1, emb_dim_2, ...]``
        soft_ids: A Tensor of weights (probabilities) used to mix the
            embedding vectors.

    Returns:
        A Tensor of shape ``shape(soft_ids)[:-1] + shape(embedding)[1:]``. For
        example, if ``shape(soft_ids) = [batch_size, max_time, vocab_size]``
        and ``shape(embedding) = [vocab_size, emb_dim]``, then the returned
        tensor has shape ``[batch_size, max_time, emb_dim]``.

    Example::

        softmax = torch.nn.Softmax(dim=-1)
        decoder_outputs, ... = decoder(...)
        soft_seq_emb = soft_embedding_lookup(
            embedding, softmax(decoder_outputs.logits))
    """
    # Contract the class axis: the last axis of ``soft_ids`` against the
    # first axis of ``embedding``.
    return torch.tensordot(soft_ids, embedding, dims=([-1], [0]))
def forward(self):
        """Run the model and populate the loss / metric attributes.

        The 3-D output is flattened to ``(n * l, classes)``, a cross-entropy
        loss is computed on it, and, when ``opt.entropy_loss_coeff > 0``, the
        negated mean of ``softmax * log_softmax`` scaled by that coefficient
        is added. ``loss_humaneness_reg`` is fixed to 0 and ``loss_total``
        equals ``loss_ce``.

        NOTE(review): three statements in this block have lost their
        right-hand side / closing arguments in the file (see inline notes).
        """
        # NOTE(review): right-hand side missing; presumably the network's
        # forward pass on the stored input -- confirm upstream.
        self.output =
        x = self.output
        [n, l , classes] = x.size()
        # Merge the first two axes so each row is one classification problem.
        x = x.view(n * l, classes)

        # print(x)
        # NOTE(review): the target argument and closing parenthesis are missing.
        self.loss_ce = F.cross_entropy(x,
        if self.opt.entropy_loss_coeff > 0:
            # S is the negated mean of p * log(p) over all entries.
            S = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
            S = -1.0 * S.mean()
            self.loss_ce += self.opt.entropy_loss_coeff * S
        # NOTE(review): the comparison's right operand and the reduction that
        # followed are missing.
        self.metric_accuracy = (torch.argmax(x,1) ==

        #TODO: implement humaneness_reg maybe
        # problem is we don't have past notes available in input, so need to do that differently
        # just use output I guess :P
        # step_size = self.opt.step_size
        # humaneness_delta = constants.HUMAN_DELTA
        # window_size = int(humaneness_delta/step_size)
        # receptive_field =
        # notes = (torch.argmax(input[:,-5:,receptive_field//2-(window_size):receptive_field//2],1)==4).float()
        # distance_factor = torch.tensor(np.exp(-2*np.arange(window_size,0,-1)/window_size)).float().cuda()
        # if self.opt.entropy_loss_coeff > 0:
        #     weights = torch.tensordot(notes,distance_factor,dims=1)
        #     humaneness_reg = F.cross_entropy(x,torch.zeros(weights.shape).long().cuda(), reduction='none')
        #     humaneness_reg =, weights)
        #     self.loss_humaneness_reg = humaneness_reg
        #     # self.loss_humaneness_reg = 0
        #     self.loss_total = self.loss_ce + self.opt.humaneness_reg_coeff * self.loss_humaneness_reg
        # else:
        #     self.loss_humaneness_reg = 0
        #     self.loss_total = self.loss_ce
        # The regulariser is disabled here: total loss is the cross-entropy term.
        self.loss_humaneness_reg = 0
        self.loss_total = self.loss_ce 
def forward(self):
        """Run the model and populate the loss / metric attributes.

        The 4-D output ``(n, channels, classes, l)`` is rearranged and
        flattened to ``(n * l * channels, classes)`` for a cross-entropy loss,
        with an optional entropy term when ``opt.entropy_loss_coeff > 0``.
        A "humaneness" regulariser is then built from a window of
        ``self.input`` and added to the total loss, scaled by
        ``opt.humaneness_reg_coeff``.

        NOTE(review): five statements in this block have lost their
        right-hand side / leading call in the file (see inline notes).
        """
        # NOTE(review): right-hand side missing; presumably the network's
        # forward pass on the stored input -- confirm upstream.
        self.output =
        x = self.output
        [n, channels, classes, l] = x.size()
        # Swap axes 1 and 3 so ``classes`` stays the second-to-last grouping
        # before flattening to (n * l * channels, classes).
        x = x.transpose(1, 3).contiguous()
        x = x.view(n * l * channels, classes)

        # NOTE(review): the target argument and closing parenthesis are missing.
        self.loss_ce = F.cross_entropy(x,
        if self.opt.entropy_loss_coeff > 0:
            # S is the negated mean of p * log(p) over all entries.
            S = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
            S = -1.0 * S.mean()
            self.loss_ce += self.opt.entropy_loss_coeff * S
        # NOTE(review): the comparison's right operand and the reduction that
        # followed are missing.
        self.metric_accuracy = (torch.argmax(x,1) ==

        step_size = self.opt.step_size
        humaneness_delta = constants.HUMAN_DELTA
        # Number of time steps covered by the humaneness window.
        window_size = int(humaneness_delta/step_size)
        # print(humaneness_reg.shape)
        # NOTE(review): right-hand side missing.
        receptive_field =
        # weights = torch.sum(torch.argmax(self.input[:,-5:,receptive_field//2-(window_size-1):receptive_field//2],1)==4,1).float()
        # 1.0 where the argmax over the last five channels equals 4, within the
        # ``window_size`` steps that end at the middle of the receptive field.
        notes = (torch.argmax(self.input[:,-5:,receptive_field//2-(window_size):receptive_field//2],1)==4).float()
        # Exponentially increasing factors from exp(-2) up to exp(-2/window_size).
        distance_factor = torch.tensor(np.exp(-2*np.arange(window_size,0,-1)/window_size)).float().cuda()
        # print(notes.shape, distance_factor.shape)
        weights = torch.tensordot(notes,distance_factor,dims=1)
        # print()
        # print(self.input[:,-5:,receptive_field//2-(window_size-1):receptive_field//2].shape)
        # self.loss_humaneness_reg = F.relu(humaneness_reg-1).mean()
        # humaneness_reg = -F.cross_entropy(x,torch.ones(weights.shape).long().cuda(), reduction='none')
        # Per-row cross-entropy against class 0.
        humaneness_reg = F.cross_entropy(x,torch.zeros(weights.shape).long().cuda(), reduction='none')
        # NOTE(review): the call that combined ``humaneness_reg`` with
        # ``weights`` has lost its head; only ", weights)" survives.
        humaneness_reg =, weights)
        self.loss_humaneness_reg = humaneness_reg
        self.loss_total = self.loss_ce + self.opt.humaneness_reg_coeff * self.loss_humaneness_reg 
def _get_torch_and_device():
    """Lazily import torch and pick a device, caching both at module level.

    On first call this imports ``torch``, chooses ``'cuda'`` when available
    (else ``'cpu'``), stores ``(torch, device)`` in ``_TORCH_DEVICE`` and
    records in ``_TORCH_HAS_TENSORDOT`` whether ``torch.tensordot`` exists.

    Returns:
        The cached ``(torch module, device string)`` tuple.
    """
    global _TORCH_DEVICE
    # Without this declaration the assignment below only created a local
    # variable, so the module-level flag was never updated.
    global _TORCH_HAS_TENSORDOT

    if _TORCH_DEVICE is None:
        import torch
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        _TORCH_DEVICE = torch, device
        _TORCH_HAS_TENSORDOT = hasattr(torch, 'tensordot')

    return _TORCH_DEVICE
def forward(self, beta, pose, trans=None, simplify=False):
        """This module takes betas and poses in a batched manner.
        A pose is 3 * K + 3 (= self.kintree_table.shape[1] * 3) parameters, where K is the number of joints.
        A beta is a vector of size self.shapedirs.shape[2], that parameterizes the body shape.
        Since this is batched, multiple betas and poses should be concatenated along zeroth dimension.
        (The reference that followed "See ... for more info." is missing from this file.)

        :param beta: batched shape coefficients, first axis is the batch.
        :param pose: batched rotation vectors, reshaped to (-1, 1, 3) internally.
        :param trans: optional translation added to vertices and joints.
        :param simplify: if true, skip the pose blend shapes.
        :return: tuple ``(v, Jtr)`` of posed vertices and transformed joints,
            both in homogeneous coordinates (last axis of size 4).
        """
        # NOTE(review): the docstring terminator and three expressions in this
        # method were truncated in the source; the ``torch.cat(...)`` heads
        # and the ``G.append(torch.matmul(G[parent], ...))`` line are
        # reconstructed from the surviving parenthesis structure -- confirm.
        batch_size = beta.shape[0]  # Size of zeroth dimension.

        # The body shape is decomposed with principal component analysis from many subjects,
        #  where self.v_template is the average value. Then shapedirs is a subset of the orthogonal directions, and
        #  a the betas are the values when the subject is projected onto these. v_shaped is the "restored" subject.
        v_shaped = torch.tensordot(beta, self.shapedirs, dims=([1], [2])) + self.v_template

        # We turn the rotation vectors into rotation matrices.
        R_cube = self.rodrigues(pose.reshape(-1, 1, 3)).reshape(batch_size, -1, 3, 3)
        J = self.regress_joints(v_shaped)  # Joints in T-pose (for limb lengths)

        if not simplify:
            # Add pose blend shapes. (How joint angles morphs the surface)
            # Now calculate how joints affects the body shape.
            lrotmin = R_cube[:, 1:] - self.eye
            lrotmin = lrotmin.reshape(batch_size, -1)
            v_shaped += torch.tensordot(lrotmin, self.posedirs, dims=([1], [2]))

        # Now we have the un-posed body shape. Convert to homogeneous coordinates.
        rest_shape_h = torch.cat(
            (v_shaped, v_shaped.new_ones(1).expand(*v_shaped.shape[:-1], 1)), 2)

        # Chain the per-joint transforms along the kinematic tree: each joint's
        # global transform is its parent's transform times its own local one.
        G = [self.rotate_translate(R_cube[:, 0], J[:, 0])]
        for i in range(1, self.kintree_table.shape[1]):
            G.append(
                torch.matmul(
                    G[self.parent[i]],
                    self.rotate_translate(R_cube[:, i], J[:, i] - J[:, self.parent[i]])))
        G = torch.stack(G, 1)
        Jtr = G[..., :4, 3].clone()
        # Subtract each transform applied to its rest-pose joint position.
        G = G - self.pack(torch.matmul(
            G, torch.cat([J, J.new_zeros(1).expand(*J.shape[:2], 1)], dim=2).unsqueeze(-1)))

        # T = torch.tensordot(self.weights, G, dims=([1], [1]))
        # v = T.reshape(-1, 4, 4).bmm(rest_shape_h.reshape(-1, 4, 1)).reshape(batch_size, -1, 4)

        # Two next lines are a memory bottleneck.
        T = torch.tensordot(G, self.weights, dims=([1], [1])).permute(0, 3, 1, 2)

        v = torch.matmul(T, torch.reshape(rest_shape_h, (batch_size, -1, 4, 1))).reshape(batch_size, -1, 4)

        if trans is not None:
            trans = trans.unsqueeze(1)
            v[..., :3] += trans
            Jtr[..., :3] += trans

        return v, Jtr
def forward(self, heatmap, points, target_hm, target_points):
        """Compute the keypoint location loss and an optional geometric loss.

        The location loss is selected by ``self.loss_type``:
        ``'l2_softargmax'``/``'l2_sm'`` (squared error on points),
        ``'l2_heatmap'``/``'l2_hm'`` (squared error on heatmaps) or
        ``'l1_softargmax'``/``'l1_sm'`` (absolute error on points).

        When ``self.include_geo`` is set, a geometric term penalises
        misaligned direction vectors between keypoints 0-6, weighted by
        ``geo_loss_gamma_vert`` and ``geo_loss_gamma_horz``; otherwise the
        geometric loss is a zero tensor.

        :param heatmap: predicted heatmaps (4-D, reduced over axes 1-3).
        :param points: predicted points (3-D, reduced over axes 1-2).
        :param target_hm: target heatmaps, same shape as ``heatmap``.
        :param target_points: target points, same shape as ``points``.
        :raises ValueError: if ``self.loss_type`` is not recognised.
        :return: ``(location_loss, geo_loss, location_loss + geo_loss)``.
        """
        if self.loss_type in ('l2_softargmax', 'l2_sm'):
            mse_loss = (points - target_points) ** 2
            location_loss = mse_loss.sum(2).sum(1).mean()
        elif self.loss_type in ('l2_heatmap', 'l2_hm'):
            mse_loss = (heatmap - target_hm) ** 2
            location_loss = mse_loss.sum(3).sum(2).sum(1).mean()
        elif self.loss_type in ('l1_softargmax', 'l1_sm'):
            l1_loss = torch.abs(points - target_points)
            location_loss = l1_loss.sum(2).sum(1).mean()
        else:
            # The source had lost this ``else:``, so the message was printed
            # for the L1 loss and an unknown type failed later with an
            # unbound ``location_loss``. Fail explicitly instead.
            raise ValueError("Did not recognize loss function selection!")

        if self.include_geo:
            # NOTE(review): tensordot over dim 1 of two (batch, 2) tensors
            # yields a (batch, batch) matrix, i.e. it also pairs vectors from
            # different samples. Kept as in the source; confirm whether a
            # per-sample dot product was intended.
            # Loss on co-linearity of points along side of cone
            v53 = F.normalize(points[:, 5] - points[:, 3], dim=1)
            v31 = F.normalize(points[:, 3] - points[:, 1], dim=1)
            vA = 1.0 - torch.tensordot(v31, v53, dims=([1], [1]))
            v10 = F.normalize(points[:, 1] - points[:, 0], dim=1)
            vB = 1.0 - torch.tensordot(v10, v31, dims=([1], [1]))

            v64 = F.normalize(points[:, 6] - points[:, 4], dim=1)
            v42 = F.normalize(points[:, 4] - points[:, 2], dim=1)
            vC = 1.0 - torch.tensordot(v64, v42, dims=([1], [1]))

            v20 = F.normalize(points[:, 2] - points[:, 0], dim=1)
            vD = 1.0 - torch.tensordot(v42, v20, dims=([1], [1]))
            # Loss on horizontals on cones (color boundaries)
            h21 = F.normalize(points[:, 2] - points[:, 1], dim=1)
            h43 = F.normalize(points[:, 4] - points[:, 3], dim=1)
            hA = 1.0 - torch.tensordot(h43, h21, dims=([1], [1]))

            h65 = F.normalize(points[:, 6] - points[:, 5], dim=1)
            hB = 1.0 - torch.tensordot(h65, h43, dims=([1], [1]))
            geo_loss = self.geo_loss_gamma_horz * (hA + hB).mean() / 2 + self.geo_loss_gamma_vert * (vA + vB + vC + vD).mean() / 4
        else:
            # The source had lost this ``else:``, which overwrote the computed
            # geometric loss with zero and left ``geo_loss`` undefined when
            # ``include_geo`` was false.
            geo_loss = torch.tensor(0)
        return location_loss, geo_loss, location_loss + geo_loss
def renormalize(*tensors):
    """Perform one renormalisation step on the corner and edge tensors.

    Index conventions (from the original comments):
    T(up,left,down,right), u=up, l=left, d=down, r=right;
    C(d,r), EL(u,r,d), EU(l,d,r).

    Args:
        *tensors: ``(C, E, T, chi)`` -- corner matrix, edge tensor, bulk
            tensor and the maximum number of kept states.

    Returns:
        ``(C / C.norm(), E, S.abs() / S.abs().max(), truncation_error)`` where
        ``S`` holds the singular values of the density matrix.
    """
    C, E, T, chi = tensors

    dimE = E.shape[0]
    dimT = T.shape[0]
    fused_dim = dimE * dimT
    D_new = min(fused_dim, chi)

    # Step 1: construct the density matrix.
    env = torch.tensordot(C, E, ([1], [0]))           # C(ef)*EU(fga)=Rho(ega)
    env = torch.tensordot(env, E, ([0], [0]))         # Rho(ega)*EL(ehc)=Rho(gahc)
    env = torch.tensordot(env, T, ([0, 2], [0, 1]))   # Rho(gahc)*T(ghdb)=Rho(acdb)
    # Rho(acdb)->Rho(ab;cd)
    rho = env.permute(0, 3, 1, 2).contiguous().view(fused_dim, fused_dim)

    rho = rho + rho.t()
    rho = rho / rho.norm()

    # Step 2: isometry from the leading singular vectors.
    U, S, _ = svd(rho)
    truncation_error = S[D_new:].sum() / S.sum()
    projector = U[:, :D_new]

    # Step 3: renormalise C and E.
    new_C = projector.t() @ rho @ projector           # C(D_new, D_new)

    ## EL(u,r,d)
    projector3 = projector.view(dimE, dimT, D_new)
    new_E = torch.tensordot(E, projector3, ([0], [0]))            # EL(def)P(dga)=E(efga)
    new_E = torch.tensordot(new_E, T, ([0, 2], [1, 0]))           # E(efga)T(gehb)=E(fahb)
    new_E = torch.tensordot(new_E, projector3, ([0, 2], [0, 1]))  # E(fahb)P(fhc)=E(abc)

    # Step 4: symmetrise C and E.
    new_C = 0.5 * (new_C + new_C.t())
    new_E = 0.5 * (new_E + new_E.permute(2, 1, 0))

    spectrum = S.abs()
    return new_C / new_C.norm(), new_E, spectrum / spectrum.max(), truncation_error
def loss(self, samples):
        """
        Computes the Distributional Q-learning loss, based on projecting the
        discounted rewards + target Q-distribution into the current Q-domain,
        with cross-entropy loss.

        Returns loss and KL-divergence-errors for use in prioritization.
        """
        # NOTE(review): the docstring delimiters, two ``else:`` lines, the
        # discount factor and the target-network call were truncated in the
        # source; ``self.discount`` and ``self.agent.target(...)`` are
        # reconstructions -- confirm against the upstream file.
        delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
        z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
        # Make 2-D tensor of contracted z_domain for each data point,
        # with zeros where next value should not be added.
        next_z = z * (self.discount ** self.n_step_return)  # [P']
        next_z = torch.ger(1 - samples.done_n.float(), next_z)  # [B,P']
        ret = samples.return_.unsqueeze(1)  # [B,1]
        next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']

        z_bc = z.view(1, -1, 1)  # [1,P,1]
        next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
        abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
        projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Most 0.
        # projection_coeffs is a 3-D tensor: [B,P,P']
        # dim-0: independent data entries
        # dim-1: base_z atoms (remains after projection)
        # dim-2: next_z atoms (summed in projection)

        with torch.no_grad():
            target_ps = self.agent.target(*samples.target_inputs)  # [B,A,P']
            if self.double_dqn:
                # Double DQN: pick the next action with the online network.
                next_ps = self.agent(*samples.target_inputs)  # [B,A,P']
                next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(next_qs, dim=-1)  # [B]
            else:
                target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(target_qs, dim=-1)  # [B]
            target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
            target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
            target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
        ps = self.agent(*samples.agent_inputs)  # [B,A,P]
        p = select_at_indexes(samples.action, ps)  # [B,P]
        p = torch.clamp(p, EPS, 1)  # NaN-guard.
        losses = -torch.sum(target_p * torch.log(p), dim=1)  # Cross-entropy.

        if self.prioritized_replay:
            losses *= samples.is_weights

        target_p = torch.clamp(target_p, EPS, 1)
        KL_div = torch.sum(target_p *
            (torch.log(target_p) - torch.log(p.detach())), dim=1)
        KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.

        if not self.mid_batch_reset:
            # Mask out entries that follow an episode end within the batch.
            valid = valid_from_done(samples.done)
            loss = valid_mean(losses, valid)
            KL_div *= valid
        else:
            loss = torch.mean(losses)

        return loss, KL_div
def tensordot(x, y, axes=2):
    """Simple translation of tensordot syntax to einsum.

    Uses the native ``torch.tensordot`` when available and otherwise builds
    an equivalent einsum expression.

    :param x: left tensor.
    :param y: right tensor.
    :param axes: an int (contract the last ``axes`` dims of ``x`` with the
        first ``axes`` dims of ``y``) or a pair of ints / int sequences
        naming the axes to contract.
    :return: the contracted tensor; free indices of ``x`` come first,
        followed by those of ``y``.
    """
    # NOTE(review): the ``if _TORCH_HAS_TENSORDOT:`` guard and the two
    # ``out_ix.append(leave)`` lines were missing in the source; without them
    # the einsum fallback was unreachable and would have produced an empty
    # output subscript.
    torch, _ = _get_torch_and_device()

    if _TORCH_HAS_TENSORDOT:
        return torch.tensordot(x, y, dims=axes)

    xnd = x.ndimension()
    ynd = y.ndimension()

    # convert int argument to (list[int], list[int])
    if isinstance(axes, int):
        axes = range(xnd - axes, xnd), range(axes)

    # convert (int, int) to (list[int], list[int])
    if isinstance(axes[0], int):
        axes = (axes[0], ), axes[1]
    if isinstance(axes[1], int):
        axes = axes[0], (axes[1], )

    # initialize empty indices
    x_ix = [None] * xnd
    y_ix = [None] * ynd
    out_ix = []

    # fill in repeated indices
    available_ix = iter(_torch_symbols_base)
    for ax1, ax2 in zip(*axes):
        repeat = next(available_ix)
        x_ix[ax1] = repeat
        y_ix[ax2] = repeat

    # fill in the rest, and maintain output order
    for i in range(xnd):
        if x_ix[i] is None:
            leave = next(available_ix)
            x_ix[i] = leave
            out_ix.append(leave)
    for i in range(ynd):
        if y_ix[i] is None:
            leave = next(available_ix)
            y_ix[i] = leave
            out_ix.append(leave)

    # form full string and contract!
    einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))
    return einsum(einsum_str, x, y)