Python源码示例:torch.jit()

示例1
def __init__(self, models, tgt_dict, max_iter=1, quantize=True, check_trace=True):
        """Build an iterative-refinement generator and keep its traced form.

        Args:
            models: models handed to ``IterativeRefinementGenerator``; also
                stored on ``self.models``.
            tgt_dict: target dictionary forwarded to the generator.
            max_iter: forwarded to the generator as ``max_iter``.
            quantize: when true, the generator's ``torch.nn.Linear`` modules are
                dynamically quantized to ``qint8`` in place before tracing.
            check_trace: forwarded to ``torch.jit.trace``.
        """
        super().__init__()
        self.models = models

        refiner = IterativeRefinementGenerator(
            self.models, tgt_dict, max_iter=max_iter
        )
        if quantize:
            refiner = torch.quantization.quantize_dynamic(
                refiner, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
            )

        # Fixed example (src_tokens, src_lengths) pair, used only to drive the trace.
        example_inputs = (torch.tensor([[4, 2]]), torch.tensor([2]))
        self.generator = torch.jit.trace(
            refiner, example_inputs, _force_outplace=True, check_trace=check_trace
        )
示例2
def conv_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    Count the flops of a convolution node captured by torch script.

    Args:
        inputs (list(torch._C.Value)): jit values feeding the convolution; the
            first two are read as the input tensor and the filter.
        outputs (list(torch._C.Value)): jit values produced by the convolution;
            only the first one is read.

    Returns:
        Counter: whatever ``conv_flop_count`` returns for the three shapes.
    """
    # A convolution node is expected to carry exactly 12 inputs:
    # 0) input tensor, 1) convolution filter, 2) bias, 3) stride, 4) padding,
    # 5) dilation, 6) transposed, 7) out_pad, 8) groups, 9) benchmark_cudnn,
    # 10) deterministic_cudnn and 11) user_enabled_cudnn.
    num_inputs = len(inputs)
    assert num_inputs == 12, num_inputs
    feature_map, kernel = inputs[:2]
    x_shape = get_shape(feature_map)
    w_shape = get_shape(kernel)
    out_shape = get_shape(outputs[0])
    return conv_flop_count(x_shape, w_shape, out_shape)
示例3
def batchnorm_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    Count the flops of a batch norm node.

    Args:
        inputs (list(torch._C.Value)): jit values feeding the batch norm; only
            the first one (the normalized tensor) is read.
        outputs (list(torch._C.Value)): jit values produced by the batch norm
            (not read here).

    Returns:
        Counter: a single ``"batchnorm"`` entry equal to four times the number
            of elements of the input.
    """
    in_shape = get_shape(inputs[0])
    rank = len(in_shape)
    # Only inputs with 2 to 5 dimensions are accepted.
    assert 2 <= rank <= 5, in_shape
    return Counter({"batchnorm": prod(in_shape) * 4})
示例4
def forward(self, enc, dec):
        """Merge an encoder feature map with an upsampled decoder feature map.

        Arguments:
            enc: Tensor from the encoder pathway
            dec: Tensor from the decoder pathway (to be upconv'd)

        Returns:
            The result of the two conv/norm/activation stages applied to the
            merged tensor.
        """
        upsampled = self.upconv(dec)
        enc, upsampled = autocrop(enc, upsampled)
        gated_enc, attention_map = self.attention(enc, dec)
        # Keep the attention output on the module, except while scripting.
        if not torch.jit.is_scripting():
            self.att = attention_map
        upsampled = self.act0(self.norm0(upsampled))

        if self.merge_mode == 'concat':
            merged = torch.cat((upsampled, gated_enc), 1)
        else:
            merged = upsampled + gated_enc

        out = self.act1(self.norm1(self.conv1(merged)))
        out = self.act2(self.norm2(self.conv2(out)))
        return out
示例5
def __init__(self,
                 module,
                 example_inputs):
        """Trace ``module`` and parse the traced graph.

        Args:
            module: callable or ``torch.nn.Module`` passed to ``torch.jit.trace``.
            example_inputs: a single example input or a list/tuple of them; a
                single input is wrapped in a list after tracing.

        Stores ``module``, ``graph``, ``trace`` and ``example_inputs`` on the
        instance and asserts that the parsed graph has the expected number of
        input nodes.
        """
        super().__init__()
        self.module = module
        wraps_nn_module = isinstance(module, torch.nn.Module)

        traced = torch.jit.trace(module, example_inputs, True)
        # Tracing takes the inputs as given; the bookkeeping below needs a sequence.
        if not isinstance(example_inputs, (list, tuple)):
            example_inputs = [example_inputs]
        num_inputs = len(example_inputs)
        parsed_graph = parse(
            traced.graph, num_inputs, omit_useless_nodes=False, is_class=wraps_nn_module)
        self.graph, self.trace, self.example_inputs = parsed_graph, traced, example_inputs

        # An nn.Module is expected to contribute one extra input node.
        expected_nodes = num_inputs + int(wraps_nn_module)
        msg = "input mismatch. this may due to some input isn't used in graph"
        assert expected_nodes == len(parsed_graph.get_input_nodes_dict()), msg
示例6
def forward(self, input, states):
        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
        """Feed ``input`` through every layer in order, one state per layer.

        Returns the last layer's output and the list of states returned by the
        layers, in layer order.
        """
        collected_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        hidden = input
        # Manual counter rather than enumerate:
        # https://github.com/pytorch/pytorch/issues/14471
        layer_idx = 0
        for layer in self.layers:
            hidden, layer_state = layer(hidden, states[layer_idx])
            collected_states.append(layer_state)
            layer_idx += 1
        return hidden, collected_states


# Differs from StackedLSTM in that its forward method takes
# List[List[Tuple[Tensor,Tensor]]]. It would be nice to subclass StackedLSTM
# except we don't support overriding script methods.
# https://github.com/pytorch/pytorch/issues/10733 
示例7
def forward(self, input, states):
        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
        """Feed ``input`` through every layer, with dropout between layers.

        Returns the final output and the list of states returned by the layers,
        in layer order.
        """
        collected_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        hidden = input
        # Manual counter rather than enumerate:
        # https://github.com/pytorch/pytorch/issues/14471
        layer_idx = 0
        for layer in self.layers:
            hidden, layer_state = layer(hidden, states[layer_idx])
            # Dropout follows every layer except the last one.
            if layer_idx < self.num_layers - 1:
                hidden = self.dropout_layer(hidden)
            collected_states.append(layer_state)
            layer_idx += 1
        return hidden, collected_states
示例8
def import_model(path=None):
    """
    Load a model from file with `torch.jit.load`.

    Parameters
    ----------
    path : str
        Location of the saved model. When None, the return value of
        `get_model_path` is used instead.

    Returns
    -------
    torch.jit.ScriptModule
        The loaded model.
    """
    if path is None:
        path = get_model_path()
    return torch.jit.load(path)
示例9
def forward(self, src_tokens, src_lengths):
        """Fork every model's encoder on the transposed tokens and gather outputs.

        Each model is switched to evaluation mode before its encoder is forked.
        The futures are handed to ``self.get_outputs`` together with the
        original (untransposed) ``src_tokens``.
        """
        # (seq_length, batch_size) for compatibility with Caffe2
        seq_first_tokens = src_tokens.t()

        pending = []
        for member in self.models:
            # evaluation mode
            member.eval()
            future = torch.jit._fork(member.encoder, seq_first_tokens, src_lengths)
            pending.append(future)

        return self.get_outputs(src_tokens, pending)
示例10
def save_to_pytorch(self, output_path):
        """Serialize this module to ``output_path`` with ``torch.jit.save``.

        Submodules exposing ``_pack`` are packed before saving and submodules
        exposing ``_unpack`` are unpacked afterwards. The unpack step runs in a
        ``finally`` clause so that a failed save does not leave the module in
        its packed state; the original exception still propagates.
        """
        def pack(s):
            if hasattr(s, "_pack"):
                s._pack()

        def unpack(s):
            if hasattr(s, "_unpack"):
                s._unpack()

        self.apply(pack)
        try:
            torch.jit.save(self, output_path)
        finally:
            self.apply(unpack)
示例11
def forward(
        self,
        src_tokens: torch.Tensor,
        src_lengths: torch.Tensor,
        prev_token: torch.Tensor,
        prev_scores: torch.Tensor,
        attn_weights: torch.Tensor,
        prev_hypos_indices: torch.Tensor,
        num_steps: int,
    ) -> List[Tuple[Tensor, float, List[float], Tensor, Tensor]]:
        """Run beam search, then decode its results.

        All arguments are forwarded unchanged to ``self.beam_search``; its four
        results plus ``num_steps`` are forwarded to ``self.beam_decode``, whose
        return value is returned.
        """
        all_tokens, all_scores, all_weights, all_prev_indices = self.beam_search(
            src_tokens,
            src_lengths,
            prev_token,
            prev_scores,
            attn_weights,
            prev_hypos_indices,
            num_steps,
        )

        # Declare the element type of the result list before it is assigned.
        outputs = torch.jit.annotate(
            List[Tuple[Tensor, float, List[float], Tensor, Tensor]], []
        )
        outputs = self.beam_decode(
            all_tokens, all_scores, all_weights, all_prev_indices, num_steps
        )
        return outputs
示例12
def save_to_pytorch(self, output_path):
        """Serialize this module to ``output_path`` with ``torch.jit.save``.

        Submodules exposing ``_pack`` are packed before saving and submodules
        exposing ``_unpack`` are unpacked afterwards. The unpack step runs in a
        ``finally`` clause so that a failed save does not leave the module in
        its packed state; the original exception still propagates.
        """
        def pack(s):
            if hasattr(s, "_pack"):
                s._pack()

        def unpack(s):
            if hasattr(s, "_unpack"):
                s._unpack()

        self.apply(pack)
        try:
            torch.jit.save(self, output_path)
        finally:
            self.apply(unpack)
示例13
def _get_all_end_states(
        self,
        beam_tokens: Tensor,
        beam_scores: Tensor,
        beam_prev_indices: Tensor,
        num_steps: int,
    ) -> Tensor:
        """Collect candidate end states and return them sorted by score.

        A (position, hypothesis) pair is a candidate when its token equals
        ``self.eos_token_id`` or when ``position`` is ``num_steps + 1``. Each
        candidate is the tensor ``[score, position, hyp_index]``; the score is
        divided by ``position ** length_penalty`` when the penalty is non-zero.
        Candidates are passed through ``self._add_to_end_states``, then stacked
        and ordered by descending score. ``beam_prev_indices`` is not read.
        """
        min_score = float("inf")
        min_index = -1
        end_states = torch.jit.annotate(List[Tensor], [])

        last_position = num_steps + 1
        for position in range(1, last_position + 1):
            is_last = position == last_position
            for hyp_index in range(self.beam_size):
                hit_eos = bool(beam_tokens[position][hyp_index] == self.eos_token_id)
                if hit_eos or is_last:
                    hypo_score = float(beam_scores[position][hyp_index])
                    if bool(self.length_penalty != 0):
                        penalty = float(position) ** float(self.length_penalty)
                        hypo_score = hypo_score / penalty
                    candidate = torch.tensor(
                        [hypo_score, float(position), float(hyp_index)]
                    )
                    end_states, min_score, min_index = self._add_to_end_states(
                        end_states, min_score, candidate, min_index
                    )

        stacked = torch.stack(end_states)

        # Column 0 holds the score; order rows from best to worst.
        _, order = stacked[:, 0].sort(dim=0, descending=True)
        return stacked[order, :]
示例14
def save_to_pytorch(self, output_path):
        """Serialize this module to ``output_path`` with ``torch.jit.save``.

        Submodules exposing ``_pack`` are packed before saving and submodules
        exposing ``_unpack`` are unpacked afterwards. The unpack step runs in a
        ``finally`` clause so that a failed save does not leave the module in
        its packed state; the original exception still propagates.
        """
        def pack(s):
            if hasattr(s, "_pack"):
                s._pack()

        def unpack(s):
            if hasattr(s, "_unpack"):
                s._unpack()

        self.apply(pack)
        try:
            torch.jit.save(self, output_path)
        finally:
            self.apply(unpack)
示例15
def generic_activation_jit(
    op_name: str,
) -> typing.Callable[[typing.List[object], typing.List[object]], typing.Counter[str]]:
    """
    Build a handle that counts activations of one operation from its output shape.

    Args:
        op_name (str): The name of the operation; used as the Counter key.

    Returns:
        typing.Callable: A handle taking ``(inputs, outputs)`` and returning a
            Counter with a single ``op_name`` entry.
    """

    def _count_activations(outputs: typing.List[object]) -> int:
        """
        Return the number of elements in the first output.

        Args:
            outputs (list(torch._C.Value)): jit values produced by the
                operation; only the first one is read.

        Returns:
            int: Product of the dimensions of the first output's shape.
        """
        return prod(get_shape(outputs[0]))

    # `inputs` is accepted to match the handle signature but is not used.
    return lambda inputs, outputs: Counter({op_name: _count_activations(outputs)})
示例16
def get_shape(val: object) -> typing.List[int]:
    """
    Get the shape from a jit value object.

    Args:
        val (torch._C.Value): jit value object.

    Returns:
        list(int): the sizes reported by the value's type.

    Raises:
        ValueError: if ``val`` is not a complete tensor, in which case no shape
            can be read from it.
    """
    if not val.isCompleteTensor():  # pyre-ignore
        # Say what failed; a bare ValueError gives the caller nothing to go on.
        raise ValueError(
            "Cannot read a shape from a jit value that is not a complete "
            "tensor: {!r}".format(val)
        )
    return val.type().sizes()  # pyre-ignore
示例17
def addmm_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    Count the flops of a fully connected layer captured by torch script.

    Args:
        inputs (list(torch._C.Value)): jit values feeding the node; the values
            at positions 1 and 2 are read.
        outputs (list(torch._C.Value)): jit values produced by the node (not
            read here).

    Returns:
        Counter: a single ``"addmm"`` entry holding the flop count.
    """
    # Only the second and third inputs matter for the count.
    shapes = [get_shape(v) for v in inputs[1:3]]
    # shapes[0] is unpacked as (batch size, input feature dimension).
    first = shapes[0]
    assert len(first) == 2, first
    # The second entry of shapes[1] is used as the output feature dimension.
    second = shapes[1]
    assert len(second) == 2, second
    batch_size, input_dim = first
    output_dim = second[1]
    return Counter({"addmm": batch_size * input_dim * output_dim})
示例18
def matmul_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    Count the flops of a matmul node.

    The second operand must be a 2-D ``[k, n]`` matrix and the last dimension
    of the first operand must equal ``k``. The count is the number of elements
    of the first operand times ``n``, so every leading dimension of the first
    operand contributes, not only its first one.

    Args:
        inputs (list(torch._C.Value)): the two jit values being multiplied.
        outputs (list(torch._C.Value)): jit values produced by the matmul (not
            read here).

    Returns:
        Counter: a single ``"matmul"`` entry holding the flop count.
    """
    # Inputs should be a list of length 2 holding the two operands.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert len(input_shapes[1]) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][0], input_shapes
    # Each element of the first operand is multiplied once per output column.
    # For a 2-D first operand this equals the previous m * k * n result.
    lhs_elements = 1
    for dim in input_shapes[0]:
        lhs_elements *= dim
    flop = lhs_elements * input_shapes[1][1]
    return Counter({"matmul": flop})
示例19
def __call__(self, *args, **kwargs):
        """Trace the wrapped method with the arguments of this call.

        Keyword arguments are used as the example inputs when any are given;
        otherwise the positional arguments are used. The trace result is stored
        on ``self.tracing_result``; nothing is returned.
        """
        wrapped = _ForwardOverrideModel(self.model, self.method_name)
        call_inputs = kwargs if kwargs else args

        # noinspection PyTypeChecker
        self.tracing_result = torch.jit.trace(
            wrapped, example_inputs={self.method_name: call_inputs}
        )
示例20
def test_cutmix():
    """Smoke-test ``cutmix`` eagerly and after ``torch.jit.script`` compilation."""
    images = torch.empty(4, 3, 32, 32)
    labels = torch.tensor([1, 2, 3, 4], dtype=torch.long)
    cutmix(images, labels, 0.1)

    scripted_cutmix = torch.jit.script(cutmix)
    scripted_cutmix(images, labels, 0.1)
示例21
def test_cutmix():
    """Smoke-test ``mixup`` eagerly and after ``torch.jit.script`` compilation."""
    # NOTE(review): the body exercises `mixup`, not `cutmix`; the name looks like
    # a copy-paste leftover and would shadow any other `test_cutmix` defined in
    # the same module. Consider renaming to `test_mixup`.
    images = torch.empty(4, 3, 32, 32)
    labels = torch.tensor([1, 2, 3, 4], dtype=torch.long)
    mixup(images, labels, 0.1)

    scripted_mixup = torch.jit.script(mixup)
    scripted_mixup(images, labels, 0.1)
示例22
def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        """Run the cell over the steps of ``input`` in reversed order.

        ``input`` is unbound along dimension 0 and passed through ``reverse``;
        the per-step outputs are passed through ``reverse`` again before being
        stacked. Returns the stacked outputs and the final state.
        """
        steps = reverse(input.unbind(0))
        collected = jit.annotate(List[Tensor], [])
        for step in steps:
            out, state = self.cell(step, state)
            collected.append(out)
        return torch.stack(reverse(collected)), state
示例23
def forward(self, input, states):
        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
        """Run every direction on the same ``input`` and join their outputs.

        ``states`` holds one state per direction
        ([forward LSTMState, backward LSTMState]). Returns the outputs
        concatenated along the last dimension and the per-direction states.
        """
        per_direction = jit.annotate(List[Tensor], [])
        new_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        # Manual counter rather than enumerate:
        # https://github.com/pytorch/pytorch/issues/14471
        idx = 0
        for direction in self.directions:
            out, out_state = direction(input, states[idx])
            per_direction.append(out)
            new_states.append(out_state)
            idx += 1
        return torch.cat(per_direction, -1), new_states
示例24
def forward(self, input, states):
        # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
        """Feed ``input`` through every layer in order.

        ``states`` is a list of per-layer state lists: the outer list is for
        layers, the inner list is for directions. Returns the last layer's
        output and the per-layer states returned by the layers.
        """
        collected_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
        hidden = input
        # Manual counter rather than enumerate:
        # https://github.com/pytorch/pytorch/issues/14471
        layer_idx = 0
        for layer in self.layers:
            hidden, layer_states = layer(hidden, states[layer_idx])
            collected_states.append(layer_states)
            layer_idx += 1
        return hidden, collected_states
示例25
def export_model(model, path=None, input_shape=(1, 3, 64, 64)):
    """
    Save a model to file with `torch.jit.save`.

    A deep copy of the model is moved to CPU and put in eval mode first. A
    `ScriptModule` is saved as is; any other model is traced on a zero tensor
    of shape `input_shape` and the traced module is saved.

    Parameters
    ----------
    model : torch.nn.Module or torch.jit.ScriptModule
        Pytorch Module or a ScriptModule.
    path : str
        File to write the model to. When None, the return value of
        `get_model_path` is used.
    input_shape : tuple or list
        Shape of the tensor used for tracing. Only needed when the model is not
        a torch.jit.ScriptModule.

    Returns
    -------
    str
        Path to where the model is saved.
    """
    if path is None:
        path = get_model_path()
    model = deepcopy(model).cpu().eval()

    if isinstance(model, torch.jit.ScriptModule):
        traced_model = model
    else:
        assert input_shape is not None, (
            "`input_shape` must be provided since model is not a "
            "`ScriptModule`."
        )
        example = torch.zeros(*input_shape)
        traced_model = trace(model, example)

    torch.jit.save(traced_model, path)
    return path
示例26
def make_representor(model, cuda=None):
    """
    Wrap a pytorch model in a callable that maps NHWC ndarrays to 2D ndarrays.

    Parameters
    ----------
    model : torch.nn.Module or torch.jit.ScriptModule
        The Pytorch model.
    cuda : bool
        Whether to run inference on CUDA. When None, the return value of
        `use_cuda` is used.

    Returns
    -------
    callable
        A callable function (`representation_function` in dlib code)
    """
    # The model is not copied: deepcopy doesn't work on ScriptModule objects yet,
    # see https://github.com/pytorch/pytorch/issues/18106
    if cuda is None:
        cuda = use_cuda()
    model = model.cuda() if cuda else model.cpu()

    def _represent(x):
        """Evaluate the model on an NHWC ndarray and return an (N, C) ndarray."""
        assert isinstance(x, np.ndarray), \
            "Input to the representation function must be a ndarray."
        assert x.ndim == 4, \
            "Input to the representation function must be a four dimensional NHWC tensor."
        # NHWC -> NCHW
        nchw = np.moveaxis(x, 3, 1)
        batch = torch.from_numpy(nchw).float().to('cuda' if cuda else 'cpu')
        # Inference only, so gradient tracking is disabled.
        with torch.no_grad():
            features = model(batch)
        features = features.cpu().numpy()
        assert features.ndim == 2, \
            "The returned output from the representor must be two dimensional (NC)."
        return features

    return _represent
示例27
def forward(self, inputs, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        """Run the cell over ``inputs`` from the last step to the first.

        Returns the outputs stacked in original step order and the final state.
        """
        collected = jit.annotate(List[Tensor], [])
        seq_len = inputs.size(0)
        last = seq_len - 1
        for i in range(seq_len):
            out, state = self.cell(inputs[last - i], state)
            # workaround for the lack of list rev support: prepend each output
            collected = [out] + collected
        return torch.stack(collected), state
示例28
def forward(self, input, states):
        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
        """Run every direction on the same ``input`` and join their outputs.

        ``states`` holds one state per direction
        ([forward LSTMState, backward LSTMState]). Returns the outputs
        concatenated along dimension 0 and the per-direction states.
        """
        per_direction = jit.annotate(List[Tensor], [])
        new_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        for (idx, direction) in enumerate(self.directions):
            out, out_state = direction(input, states[idx])
            per_direction += [out]
            new_states += [out_state]
        # tensor array concat assumes axis == 0 for now, so the outputs are
        # joined along dimension 0 rather than the last dimension
        return torch.cat(per_direction, 0), new_states
示例29
def forward(self, input, states):
        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
        """Feed ``input`` through every layer in order, one state per layer.

        Returns the last layer's output and the list of states returned by the
        layers, in layer order.
        """
        collected_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        hidden = input
        for (layer_idx, layer) in enumerate(self.layers):
            hidden, layer_state = layer(hidden, states[layer_idx])
            collected_states += [layer_state]
        return hidden, collected_states
示例30
def forward(self, input, states):
        # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
        """Feed ``input`` through every layer in order.

        ``states`` is a list of per-layer state lists: the outer list is for
        layers, the inner list is for directions. Returns the last layer's
        output and the per-layer states returned by the layers.
        """
        collected_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
        hidden = input
        for (layer_idx, layer) in enumerate(self.layers):
            hidden, layer_states = layer(hidden, states[layer_idx])
            collected_states += [layer_states]
        return hidden, collected_states