PyTorch's sparse transpose

Ulf Hamster
python pytorch sparse matrix coo

%%capture
!pip install "torch>=1.3.1"
# load packages
import torch

# check version
print(f"torch version: {torch.__version__}")

# set GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"device type: {device}")
torch version: 1.3.1
device type: cpu

Functions

def init_coo_random(n_rows: int, n_cols: int, n_elem: int, 
                    dtype: torch.dtype = torch.float32, 
                    device: torch.device = torch.device("cpu")
                   ) -> torch.Tensor:
    with torch.no_grad():
        # generate random indices
        indices = torch.stack([
            torch.randint(size=(n_elem,), high=n_rows),
            torch.randint(size=(n_elem,), high=n_cols)
        ]).to(device)
        # drop duplicate (row, col) pairs so every index is unique
        indices = torch.unique(indices, sorted=True, dim=1)
        # generate normal distributed values
        n_elem2 = indices.size()[1]
        values = torch.randn((n_elem2,), device=device)
        # create COO matrix
        matrix = torch.sparse_coo_tensor(
            indices=indices, 
            values=values, 
            size=[n_rows, n_cols], 
            dtype=dtype,
            device=device
        ).coalesce()
    # done
    return matrix
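
The deduplication step is why the matrix can end up with fewer than `n_elem` entries: `torch.unique(..., dim=1)` treats each column of the index tensor as one (row, col) pair and keeps a single copy of each. A minimal sketch with made-up indices (the values below are illustrative only):

```python
import torch

# four (row, col) pairs, stored column-wise; the pair (2, 3) appears twice
indices = torch.tensor([[2, 0, 2, 1],
                        [3, 1, 3, 0]])

# dim=1 compares whole columns, so duplicates of a pair collapse to one
unique_indices = torch.unique(indices, sorted=True, dim=1)

print(indices.size(1), "->", unique_indices.size(1))
print(unique_indices)
```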
# flip the indices
def transpose(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        matrix = torch.sparse_coo_tensor(
            indices=torch.stack([x.indices()[1], x.indices()[0]]), 
            values=x.values(), 
            size=(x.size(1), x.size(0)),  # transposed shape: swap rows and columns
            dtype=x.dtype,
            device=x.device).coalesce()
    return matrix
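
A quick way to sanity-check an index-flipping transpose is to compare it with the built-in `.t()` on a small non-square matrix, where a wrong output shape shows up immediately. The sketch below is self-contained; the helper name `transpose_coo` and the 2 x 3 example values are assumptions made for this check, not part of the benchmark.

```python
import torch

def transpose_coo(x: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: swap the row/column indices and the two sizes
    x = x.coalesce()
    idx = x.indices()
    return torch.sparse_coo_tensor(
        indices=torch.stack([idx[1], idx[0]]),
        values=x.values(),
        size=(x.size(1), x.size(0)),
    ).coalesce()

# small 2 x 3 matrix with three nonzero entries
A = torch.sparse_coo_tensor(
    indices=torch.tensor([[0, 0, 1], [0, 2, 1]]),
    values=torch.tensor([1.0, 2.0, 3.0]),
    size=(2, 3),
)

At_custom = transpose_coo(A)
At_builtin = A.t()

# compare via dense tensors, which is fine for a matrix this small
same = torch.equal(At_custom.to_dense(), At_builtin.to_dense())
print(At_custom.shape, same)
```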

Create COO Matrix

%%time
n_rows, n_cols, n_elem = 10000, 10000, 100*10000
W = init_coo_random(n_rows, n_cols, n_elem,
                    dtype=torch.float32, device=device)
W
CPU times: user 7.52 s, sys: 1.32 s, total: 8.85 s
Wall time: 8.87 s

Transpose with custom function

%timeit Wt = transpose(W)
1 loop, best of 3: 277 ms per loop

Transpose with built-in function

%timeit Wt = W.t()
100 loops, best of 3: 10.8 ms per loop

Conclusion

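The built-in `W.t()` is roughly 25 times faster than the custom function here (10.8 ms versus 277 ms per call on CPU). Part of the gap may come from the custom function's final `.coalesce()` call, which the built-in timing does not include.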
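
Since the custom `transpose` ends with `.coalesce()`, a rough like-for-like comparison can time `W.t().coalesce()` alongside plain `W.t()`. The sketch below uses the standard `timeit` module instead of the notebook magic; the sizes, the seed, and the name `W_small` are assumptions chosen so it runs quickly, and the timings will vary by machine and PyTorch version.

```python
import timeit
import torch

torch.manual_seed(0)
n_rows, n_cols, n_elem = 1000, 2000, 50000  # assumed sizes, much smaller than the benchmark

# random (row, col) pairs with duplicates removed
idx = torch.stack([
    torch.randint(high=n_rows, size=(n_elem,)),
    torch.randint(high=n_cols, size=(n_elem,)),
])
idx = torch.unique(idx, dim=1)
W_small = torch.sparse_coo_tensor(
    idx, torch.randn(idx.size(1)), size=(n_rows, n_cols)
).coalesce()

# average seconds per call over 20 runs
n_runs = 20
t_plain = timeit.timeit(lambda: W_small.t(), number=n_runs) / n_runs
t_coalesced = timeit.timeit(lambda: W_small.t().coalesce(), number=n_runs) / n_runs

print(f"W.t():            {t_plain * 1e3:.3f} ms")
print(f"W.t().coalesce(): {t_coalesced * 1e3:.3f} ms")
print(f"W.t() coalesced?  {W_small.t().is_coalesced()}")
```

If the coalesced variant turns out close to the custom function's timing, most of the difference is the cost of re-sorting the indices rather than of the transpose itself.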