torch.segment_reduce
- torch.segment_reduce(data: Tensor, reduce: str, *, lengths: Tensor | None = None, indices: Tensor | None = None, offsets: Tensor | None = None, axis: _int = 0, unsafe: _bool = False, initial: Number | _complex | None = None) → Tensor
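- Perform a segment reduction operation on the input tensor along the specified axis.
- Parameters
- data (Tensor) – The input tensor to reduce.
- reduce (str) – The reduction to apply: 'sum', 'mean', 'max', 'min', or 'prod'.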
- Keyword Arguments
- lengths (Tensor, optional) – Length of each segment along axis. Default: None.
- offsets (Tensor, optional) – Offset of each segment. Either lengths or offsets must be given. Default: None.
- axis (int, optional) – The axis along which to perform the reduction. Default: 0.
- unsafe (bool, optional) – If True, skip validation of lengths (that they are non-negative and sum to the size of data along axis). Default: False.
- initial (Number, optional) – The initial value for the reduction operation. Default: None.
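For intuition, a lengths-based reduction along axis 0 can be sketched in plain Python. This is a hypothetical reference helper (`segment_reduce_rows` and `offsets_to_lengths` are illustrative names), not PyTorch's kernel. It assumes that an omitted `initial` falls back to the reduction's identity (0 for sum, 1 for prod, -inf for max, +inf for min) and that offsets are segment boundaries with a leading 0; it does not model how `initial` combines with 'mean' or how NaN values propagate.

```python
from typing import List, Optional, Sequence

# Identity used when `initial` is omitted (assumption for this sketch).
_IDENTITY = {"sum": 0.0, "prod": 1.0, "max": float("-inf"), "min": float("inf")}


def _reduce(values: Sequence[float], reduce: str, initial: Optional[float]) -> float:
    """Reduce one column slice of a segment to a single number."""
    if reduce == "mean":
        # This sketch ignores `initial` for 'mean'; an empty segment gives NaN.
        return sum(values) / len(values) if values else float("nan")
    if reduce not in _IDENTITY:
        raise ValueError(f"unsupported reduction: {reduce!r}")
    acc = _IDENTITY[reduce] if initial is None else initial
    for v in values:
        if reduce == "sum":
            acc += v
        elif reduce == "prod":
            acc *= v
        elif reduce == "max":
            acc = max(acc, v)
        else:  # "min"
            acc = min(acc, v)
    return acc


def segment_reduce_rows(
    data: List[List[float]],
    reduce: str,
    lengths: List[int],
    initial: Optional[float] = None,
) -> List[List[float]]:
    """Hypothetical helper: reduce consecutive row segments along axis 0."""
    # Mirrors the kind of check that unsafe=True would skip.
    if any(n < 0 for n in lengths) or sum(lengths) != len(data):
        raise ValueError("lengths must be non-negative and sum to len(data)")
    num_cols = len(data[0]) if data else 0
    out: List[List[float]] = []
    start = 0
    for n in lengths:
        rows = data[start:start + n]  # the n rows that form this segment
        out.append([_reduce([row[c] for row in rows], reduce, initial) for c in range(num_cols)])
        start += n
    return out


def offsets_to_lengths(offsets: List[int]) -> List[int]:
    """Hypothetical helper: boundaries like [0, 2, 3] become lengths [2, 1]."""
    return [end - begin for begin, end in zip(offsets, offsets[1:])]
```

For instance, with three rows and lengths [2, 1], the first output row reduces rows 0 and 1 together and the second output row is row 2 on its own, so the output has one row per segment.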
 
- Example:

    >>> import torch
    >>> data = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=torch.float32, device='cuda')
    >>> lengths = torch.tensor([2, 1], device='cuda')
    >>> torch.segment_reduce(data, 'max', lengths=lengths)
    tensor([[ 5.,  6.,  7.,  8.],
            [ 9., 10., 11., 12.]], device='cuda:0')
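The example above needs a CUDA device. A CPU variant of the same call is sketched below; it also passes the segments as offsets instead of lengths, assuming (this page does not state it) that offsets are segment boundaries with a leading 0, so lengths [2, 1] correspond to offsets [0, 2, 3].

```python
import torch

# Same data as the example, but kept on the CPU.
data = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=torch.float32)
lengths = torch.tensor([2, 1])
# Assumption: offsets list segment boundaries, starting at 0 and ending at data.size(0).
offsets = torch.tensor([0, 2, 3])

by_lengths = torch.segment_reduce(data, "max", lengths=lengths)
by_offsets = torch.segment_reduce(data, "max", offsets=offsets)
sums = torch.segment_reduce(data, "sum", lengths=lengths)

print(by_lengths)
print(torch.equal(by_lengths, by_offsets))
print(sums)
```

Both calls produce one output row per segment; only the way the segment boundaries are described differs.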