
def map_pool(output_size, dtype=np.float32):
    """Build a max-pooling observation transform for square occupancy maps.

    Args:
        output_size: shape passed to the returned ``spaces.Box``.
        dtype: dtype passed to the returned ``spaces.Box``.

    Returns:
        A ``_thunk(obs_space)`` factory. Calling it returns ``(runner, space)``,
        where ``space`` is ``spaces.Box(-1, 1, output_size, dtype)`` and
        ``runner`` converts an ``N, H, W, C`` batch into an ``N, C, H, W``
        CPU tensor.
    """
    def _thunk(obs_space):
        def runner(x):
            """Scale, max-pool, rotate and re-center a batch of square maps.

            Requires a CUDA device. Values are divided by 255 and dilated with
            a 3x3 stride-1 max-pool (spatial size unchanged). They are then
            rotated 180 degrees in the spatial plane and mapped by ``2 * v - 1``,
            so inputs in [0, 255] end up in [-1, 1]. The result is returned
            on the CPU.
            """
            with torch.no_grad():
                # input: n x h x w x c
                # output: n x c x h x w  and normalized, ready to pass into net.forward
                # The spatial dims of an N,H,W,C batch are 1 and 2 (dim 0 is the batch).
                assert x.shape[1] == x.shape[2], 'we are only using square data, data format: N,H,W,C'
                if isinstance(x, torch.Tensor):  # for training
                    x = torch.cuda.FloatTensor(x.cuda())
                else:  # for testing
                    # copy() gives a fresh C-contiguous array, presumably to
                    # handle non-contiguous / negative-stride views.
                    x = torch.cuda.FloatTensor(x.copy()).cuda()

                x = x.permute(0, 3, 1, 2) / 255.0
                x = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
                # Face north (this could be computed using a different world2agent transform)
                x = torch.rot90(x, k=2, dims=(2, 3))
                x = 2.0 * x - 1.0
                return x.cpu()

        return runner, spaces.Box(-1, 1, output_size, dtype)
    return _thunk
def map_pool_collated(output_size, dtype=np.float32):
    """Factory for a batched max-pool / normalize transform on square maps.

    Args:
        output_size: shape passed to the returned ``spaces.Box``.
        dtype: dtype passed to the returned ``spaces.Box``.

    Returns:
        ``_thunk(obs_space)``, which returns ``(runner, space)``. ``runner``
        maps an ``N, H, W, C`` batch to an ``N, C, H, W`` CUDA tensor, and
        ``space`` is ``spaces.Box(-1, 1, output_size, dtype)``.
    """
    def _thunk(obs_space):
        def runner(x):
            # Steps: /255 -> 3x3 stride-1 max-pool -> 180-degree spatial
            # rotation ("face north") -> affine map v -> 2v - 1.
            # The result stays on the GPU.
            with torch.no_grad():
                assert x.shape[2] == x.shape[1], 'we are only using square data, data format: N,H,W,C'
                # Tensors (training) are moved directly; arrays (testing) are copied first.
                on_gpu = (
                    torch.cuda.FloatTensor(x.cuda())
                    if isinstance(x, torch.Tensor)
                    else torch.cuda.FloatTensor(x.copy()).cuda()
                )

                channels_first = on_gpu.permute(0, 3, 1, 2) / 255.0
                dilated = F.max_pool2d(channels_first, kernel_size=3, stride=1, padding=1)
                facing_north = torch.rot90(dilated, k=2, dims=(2, 3))
                return 2.0 * facing_north - 1.0

        return runner, spaces.Box(-1, 1, output_size, dtype)
    return _thunk
def apply_tta(input):
    """Return seven flipped/rotated variants of ``input`` for test-time augmentation.

    Transforms act on dims 2 and 3 (presumably H and W of an N,C,H,W batch).
    The untouched input is not included. Order of the returned list:
    flip dim 2, flip dim 3, rot90 k=1, k=2, k=3, then flip dim 2 followed
    by rot90 k=1 and k=3.
    """
    plane = [2, 3]
    flipped_rows = torch.flip(input, dims=[2])
    rotations = [torch.rot90(input, k=quarter, dims=plane) for quarter in (1, 2, 3)]
    return [
        flipped_rows,
        torch.flip(input, dims=[3]),
        *rotations,
        torch.rot90(flipped_rows, k=1, dims=plane),
        torch.rot90(flipped_rows, k=3, dims=plane),
    ]
def forward(self, x, y):
    """Apply one shared random rotation/flip to ``x`` and ``y``, then score them.

    When ``self.rotations`` is truthy, both tensors are rotated by the same
    random k in {-1, 0, 1} quarter turns in dims (2, 3). When ``self.flips``
    is truthy, dims 2 and 3 are each independently flipped (for both
    tensors) with a coin toss. Returns ``self.loss(x, y)`` on the
    transformed pair.
    """
    pair = (x, y)
    if self.rotations:
        quarter_turns = random.choice([-1, 0, 1])
        pair = tuple(torch.rot90(t, quarter_turns, [2, 3]) for t in pair)
    if self.flips:
        # One coin per axis, drawn in order dim 2 then dim 3.
        for axis in (2, 3):
            if random.choice([True, False]):
                pair = tuple(torch.flip(t, (axis,)) for t in pair)
    return self.loss(*pair)
def rot90(x, k=1):
    """Rotate a batch of images by ``k`` quarter turns in the plane of dims 2 and 3."""
    spatial_plane = (2, 3)
    return torch.rot90(x, k, spatial_plane)
def forward(self, batch, angle):
    """Rotate ``batch`` by ``angle`` degrees in the plane of dims (2, 3).

    Positive angles rotate from dim 2 towards dim 3 (counterclockwise for
    row/column image layouts). The quarter-turn count is ``angle // 90``
    (floor division), so ``angle`` is expected to be a multiple of 90.
    Other values are rounded down to the previous multiple.

    Args:
        batch: tensor with at least 4 dims; dims 2 and 3 are rotated.
        angle: rotation in degrees; an int, a float such as ``90.0``, or a
            0-dim tensor.

    Returns:
        The rotated tensor.
    """
    # rotation is counterclockwise
    # int() turns float or 0-dim tensor results (e.g. 90.0 // 90 == 1.0) into
    # the Python int that torch.rot90 requires for ``k``.
    k = int(angle // 90)
    return torch.rot90(batch, k, (2, 3))
def transform(self, x):
    """Rotate ``x`` by ``self.k`` quarter turns, from the last axis towards the second-to-last."""
    quarter_turns = self.k
    return x.rot90(quarter_turns, [-1, -2])
def transform(self, x):
    """Rotate ``x`` by ``self.k`` quarter turns, from the second-to-last axis towards the last."""
    quarter_turns = self.k
    return x.rot90(quarter_turns, [-2, -1])
def backward(self, inp):
    """Undo the augmentation by re-applying ``self.forward`` to ``inp``.

    NOTE(review): this is only a true inverse when ``forward`` is its own
    inverse (e.g. a flip). Confirm this for the owning class.
    """
    apply_forward = self.forward
    return apply_forward(inp)

# class Rot90Augment:
#     def __init__(self, k, dims):
#         self.k = k
#         self.dims = dims
#     def forward(self, inp):
#         return torch.rot90(inp, k=self.k, dims=self.dims)
#     def backward(self, inp):
#         return torch.rot90(inp, k=-self.k, dims=self.dims) 
def rot90(data: torch.Tensor, k: int, dims: Union[int, Sequence[int]]):
    """Rotate ``data`` by 90 degrees ``k`` times in the plane spanned by ``dims``.

    ``dims`` index the spatial axes only: the two leading axes (presumably
    batch and channel) are skipped by adding 2 to every entry before
    calling ``torch.rot90``.

    Args:
        data: input data.
        k: number of times to rotate; positive values rotate from the first
            towards the second of ``dims``. Converted with ``int()``.
        dims: the two (non-negative) spatial dimensions defining the rotation
            plane. Despite the annotation, this must be an iterable of two ints.

    Returns:
        torch.Tensor: tensor rotated in the given plane.
    """
    dims = [int(d + 2) for d in dims]
    return torch.rot90(data, int(k), dims)
def train():
    """Run one training epoch with an auxiliary rotation-prediction loss.

    For every batch ``(bx, by)`` of size N the network receives 5N images:
    - The N originals, labelled by ``by`` for the classification loss.
    - The batch rotated by 0, 90, 180 and 270 degrees in dims (2, 3),
      labelled 0..3 in ``by_prime`` and scored through
      ``net.module.rot_pred`` with weight 0.5.

    Inputs go through ``adversary`` first and are mapped by ``2 * x - 1``
    before the forward pass. An exponential moving average of the loss is
    stored in ``state['train_loss']``.

    Relies on module-level ``net``, ``train_loader``, ``adversary``,
    ``optimizer`` and ``state``.
    """
    net.train()  # enter train mode
    loss_avg = 0.0
    for bx, by in train_loader:
        curr_batch_size = bx.size(0)
        # Rotation labels [0]*N + [1]*N + [2]*N + [3]*N (int64), in the same
        # order as the rotated copies concatenated below.
        by_prime = torch.arange(4).repeat_interleave(curr_batch_size)
        bx = bx.float()
        bx = torch.cat(
            (
                bx,
                bx,
                torch.rot90(bx, 1, (2, 3)),
                torch.rot90(bx, 2, (2, 3)),
                torch.rot90(bx, 3, (2, 3)),
            ),
            0,
        )
        bx, by, by_prime = bx.cuda(), by.cuda(), by_prime.cuda()

        adv_bx = adversary(net, bx, by, by_prime, curr_batch_size)

        # forward
        logits, pen = net(adv_bx * 2 - 1)

        # backward
        # NOTE(review): assumes the optimizer for ``net`` is the module-level
        # ``optimizer``. Without this step the epoch never updates weights.
        optimizer.zero_grad()
        loss = F.cross_entropy(logits[:curr_batch_size], by)
        loss += 0.5 * F.cross_entropy(net.module.rot_pred(pen[curr_batch_size:]), by_prime)
        loss.backward()
        optimizer.step()

        # exponential moving average
        loss_avg = loss_avg * 0.9 + float(loss) * 0.1

    state['train_loss'] = loss_avg

# test function 
def get_random_augmenters(
        ndim: int
) -> Tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[torch.Tensor], torch.Tensor]]:
    """Produce a pair of functions ``augment, reverse_augment``.

    The ``augment`` function applies a random augmentation to a torch
    tensor. The ``reverse_augment`` function performs the reverse
    augmentations where applicable (i.e. for geometrical transformations),
    so pixel-level loss calculation is still correct.

    The random choices are drawn once, when this function is called, from
    torch's global RNG:
    - a rotation by ``k90`` quarter turns in the last two dims;
    - a random subset of dims ``2 .. ndim - 1`` to flip.
    Both returned functions reuse those same choices.

    Note that all augmentations are performed on the compute device that
    holds the input, so generally on the GPU.

    Args:
        ndim: number of dims of the tensors to be augmented, including the
            two leading dims that are never flipped (presumably batch and
            channel).

    Returns:
        ``(augment, reverse_augment)`` such that
        ``reverse_augment(augment(x))`` equals ``x``.
    """
    # Random rotation angle (in 90 degree steps)
    k90 = torch.randint(0, 4, ()).item()
    # Get a random selection of spatial dims (ranging from [] to [2, 3, ..., example.ndim - 1]
    flip_dims_binary = torch.randint(0, 2, (ndim - 2,))
    flip_dims = (torch.nonzero(flip_dims_binary, as_tuple=False).squeeze(1) + 2).tolist()

    def augment(x: torch.Tensor) -> torch.Tensor:
        """Rotate by ``k90`` quarter turns, then flip ``flip_dims``."""
        x = torch.rot90(x, +k90, (-1, -2))
        if len(flip_dims) > 0:
            x = torch.flip(x, flip_dims)

        # # Uncomment to enable additional random brightness and contrast augmentations
        # contrast_std = 0.1
        # brightness_std = 0.1
        # a = torch.randn(x.shape[:2], device=x.device, dtype=x.dtype) * contrast_std + 1.0
        # b = torch.randn(x.shape[:2], device=x.device, dtype=x.dtype) * brightness_std
        # for n in range(x.shape[0]):
        #     for c in range(x.shape[1]):
        #         # Formula based on tf.image.{adjust_contrast,adjust_brightness}
        #         # See the TensorFlow documentation of those two functions.
        #         m = torch.mean(x[n, c])
        #         x[n, c] = a[n, c] * (x[n, c] - m) + m + b[n, c]

        # # Uncomment to enable additional additive gaussian noise augmentations
        # agn_std = 0.1
        # x.add_(torch.randn_like(x).mul_(agn_std))

        return x

    def reverse_augment(x: torch.Tensor) -> torch.Tensor:
        """Undo ``augment``: flip the same dims again, then rotate back."""
        # Skip torch.flip for an empty dim list; per the original note, the
        # check is necessary only on CUDA.
        if len(flip_dims) > 0:
            x = torch.flip(x, flip_dims)
        x = torch.rot90(x, -k90, (-1, -2))
        return x

    return augment, reverse_augment
def construct_occupancy_map(self) -> np.ndarray:
        """Render the visitation history and point-goal as an egocentric occupancy map.

        Returns:
            np.ndarray: uint8 array of shape ``(MAP_SIZE, MAP_SIZE, 3)``
            (channels last).
            - Channel 0 marks cells the agent visited.
            - Channel 1 marks the goal cell.
            - Channel 2 is left at the ambient value.
            Ambient pixels are 128 and marked pixels 255. If ``self.max_pool``
            is truthy, marks are dilated by a 3x3 stride-1 max-pool before
            returning. With an empty ``self.history`` an all-zero map is
            returned instead.
        """
        # does not need second_last_pointgoal
        if len(self.history) == 0:
            # NOTE(review): this early-out fills with 0, not the ambient 128 used below — confirm intended.
            return np.zeros((MAP_SIZE, MAP_SIZE, 3), dtype=np.uint8)

        # last_state[:2] is the current (polar) position; last_state[2] is the current heading.
        cur_agent_pos_polar, cur_agent_heading = self.last_state[:2], self.last_state[2]
        cur_agent_pos_xy = convert_polar_to_xy(cur_agent_pos_polar)

        # Keep the first two columns of every history row (the position) and drop the heading.
        # NOTE(review): np.array already makes a copy, so copy.deepcopy looks redundant.
        global_coords_polar = copy.deepcopy(np.array(self.history))[:,:2]  # throw away heading
        global_coords_xy = convert_polar_to_xy(global_coords_polar)

        # Translate so the current agent position is the origin, then rotate by
        # -(cur_agent_heading - self.init_heading) so the initial agent always points 'north'.
        # The rotation is negative only because, per the original design, everything is
        # rotated by PI afterwards. That rot90 is commented out below, so presumably the
        # PI rotation now happens downstream — confirm.
        agent_coords = global_coords_xy - cur_agent_pos_xy
        agent_coords = rotate(agent_coords, -1 * (cur_agent_heading - self.init_heading))
#         agent_coords = rotate(agent_coords, -1 * (np.pi + cur_agent_heading - self.init_heading))

        # calculate goal coordinates (independent of forward model)
        # No rotation is applied to the goal here (the +pi offset is commented out);
        # presumably last_pointgoal is already relative to the agent — confirm.
        last_pointgoal_rotated = self.last_pointgoal #+ np.array([0, np.pi])
        goal = convert_polar_to_xy(last_pointgoal_rotated)
        goal_coords = np.array([goal])

        # quantize: shift by half of self.max_building_size (presumably so the origin lands
        # mid-map) and convert positions to integer cell indices via pos_to_map.
        visitation_cells = pos_to_map(agent_coords + self.max_building_size / 2, cell_size=self.cell_size)
        goal_cells = pos_to_map(goal_coords + self.max_building_size / 2, cell_size=self.cell_size)

        # plot (make ambient pixels 128 so that they are 0 when pass into nn)
        # Built channels-first (3, MAP_SIZE, MAP_SIZE) and permuted to channels-last before return.
        # NOTE(review): assumes pos_to_map returns indices inside [0, MAP_SIZE); otherwise indexing raises.
        omap = torch.full((3, MAP_SIZE, MAP_SIZE), fill_value=128, dtype=torch.uint8, device=None, requires_grad=False) # Avoid multiplies, stack, and copying to torch
        omap[0][visitation_cells[:, 0], visitation_cells[:, 1]] = 255 # Agent visitation
        omap[1][goal_cells[:, 0], goal_cells[:, 1]] = 255 # Goal
        # omap[2][visitation_cells[-1][0], visitation_cells[-1][1]] = 255 # Agent itself

        # omap = np.rot90(omap, k=2, axes=(0,1))
        # WARNING: with code checkpoints, we need the map to be rotated
        # omap = torch.rot90(omap, k=2, dims=(1,2)) # Face north (this could be computed using a different world2agent transform)
        if self.max_pool:
            # max_pool2d needs a floating tensor; cast back to uint8 afterwards.
            omap = F.max_pool2d(omap.float(), kernel_size=3, stride=1, padding=1).byte()

        omap = omap.permute(1, 2, 0).cpu().numpy()
        assert omap.dtype == np.uint8, f'Omap needs to be uint8, currently {omap.dtype}'
        return omap