Benchmark Sparse Matrix-Vector Multiplication

Ulf Hamster 6 min.
python sparse scipy

Summary

We take the spiking RNN from the previous notebook and compare different sparse matrix formats for the weights W and the states s.

The results

Load Packages

import numpy as np

from scipy.sparse import (
    bsr_matrix, coo_matrix, csc_matrix, csr_matrix,
    dia_matrix, dok_matrix, lil_matrix)

from line_profiler import LineProfiler
#import numpy as np
#from scipy.sparse import coo_matrix

def randn_coo_matrix(n_rows, n_cols, n_elem, nodiag=True, dtype=np.float32):
    """Create a random sparse COO matrix with standard-normal entries.

    ``n_elem`` candidate positions are drawn uniformly at random. Positions
    on the main diagonal are discarded when ``nodiag`` is set, and duplicate
    positions are merged, so the number of stored entries is at most
    ``n_elem`` and usually somewhat smaller.

    Parameters
    ----------
    n_rows, n_cols : int
        Shape of the matrix.
    n_elem : int
        Number of candidate (row, col) positions to draw.
    nodiag : bool, optional
        If True (default), drop every candidate with ``row == col``.
    dtype : data-type, optional
        Data type of the stored values (default ``np.float32``).

    Returns
    -------
    scipy.sparse.coo_matrix
        Matrix of shape ``(n_rows, n_cols)`` whose stored values are of
        type ``dtype``.
    """
    # generate random indices
    idx_rows = np.random.randint(0, high=n_rows, size=(n_elem,))
    idx_cols = np.random.randint(0, high=n_cols, size=(n_elem,))
    # delete diagonal elements
    if nodiag:
        mask = idx_rows != idx_cols
        idx_rows = idx_rows[mask]
        idx_cols = idx_cols[mask]
    # coalesce: keep every (row, col) pair only once
    indices = np.unique(np.c_[idx_rows, idx_cols], axis=0)
    n_unique = indices.shape[0]
    # generate normal distributed values in the requested dtype
    # (the dtype argument used to be silently ignored)
    values = np.random.standard_normal(size=(n_unique,)).astype(dtype, copy=False)
    # create COO matrix
    return coo_matrix((values, (indices[:, 0], indices[:, 1])),
                      shape=(n_rows, n_cols))

Model Settings

np.random.seed(42)  # Reproducibility: all random draws below depend on this seed and on their order
dtype_flt = np.float32  # Floating point type used for gam, ups and W

n_iter = 5000  # Number of time steps to simulate
n_states = 15000  # Number of states (one per "neuron")

# Transfer function f (scipy's expit, the logistic sigmoid)
from scipy.special import expit
f_fn = expit  

# EWA smoothing factor Gamma ("input resistance"), one value per state,
# drawn uniformly from about [0.01, 0.10)
gam = (0.01 + 0.09 * np.random.rand(n_states)).astype(dtype=dtype_flt)

# Firing threshold, one value per state, drawn uniformly from about [0.05, 0.10)
ups = (0.05 + 0.05 * np.random.rand(n_states)).astype(dtype=dtype_flt)

# Activation function g: 1 (spike) where x exceeds the threshold a, else 0
g_fn = lambda x, a: (x > a).astype(np.uint8)

# Initial states: every state starts at 1
s0 = np.ones(shape=(n_states,)).astype(np.uint8)

# Weight matrix: random sparse matrix built from 10*n_states candidate positions
# (randn_coo_matrix drops diagonal and duplicate positions, so fewer are stored)
W = 1. * randn_coo_matrix(n_states, n_states, 10*n_states).astype(dtype=dtype_flt)

Test 1

def forward1(W, s0, f_fn, gam, ups, g_fn, n_states, n_iter, dtype_flt=np.float64):
    """Simulate the spiking RNN with dense state vectors.

    For every time step ``t >= 1`` the update is::

        u[:, t] = gam * f_fn(W.dot(s[:, t-1])) + (1 - gam) * u[:, t-1]
        s[:, t] = g_fn(u[:, t], ups)
        u[:, t] = 0   wherever s[:, t] is non-zero

    Parameters
    ----------
    W : matrix-like with a ``dot`` method
        Weight matrix of shape ``(n_states, n_states)``, dense or sparse.
    s0 : array_like
        Initial state, written into column 0 of ``s``.
    f_fn : callable
        Transfer function applied to ``W.dot(s)``.
    gam : ndarray
        Smoothing factor per state.
    ups : ndarray
        Firing threshold per state, forwarded to ``g_fn``.
    g_fn : callable
        Activation function ``g_fn(u, ups)`` returning the spikes.
    n_states, n_iter : int
        Number of states and number of time steps.
    dtype_flt : data-type, optional
        Floating point type of ``u``.

    Returns
    -------
    s : ndarray of uint8, shape (n_states, n_iter), Fortran order
        Spike history.
    u : ndarray of dtype_flt, shape (n_states, n_iter), Fortran order
        History of the internal value ("membrane potential") after resets.
    """
    # Column-major storage keeps each time step contiguous in memory
    dims = (n_states, n_iter)
    u = np.zeros(dims, dtype=dtype_flt, order='F')
    s = np.zeros(dims, dtype=np.uint8, order='F')
    s[:, 0] = s0

    # weight of the previous internal value in the moving average
    mem = (1 - gam).astype(dtype_flt)

    for t in range(1, n_iter):
        prev = t - 1

        # internal value u ("membrane potential")
        drive = f_fn(W.dot(s[:, prev]))
        u[:, t] = gam * drive + mem * u[:, prev]

        # state value s (spikes)
        s[:, t] = g_fn(u[:, t], ups)

        # every state that fired (s=1) restarts from u=0
        fired = s[:, t] != 0
        u[fired, t] = 0

    return s, u

Test 1a - W:dense, s:dense

#W1a = W.toarray()
#%time s, u = forward1(W1a, s0, f_fn, gam, ups, g_fn, n_states, n_iter)
CPU times: user 6min 26s, sys: 1.9 s, total: 6min 28s
Wall time: 3min 22s

Test 1b - W:sparse, s:dense

# Time forward1 once per sparse format of W (IPython %time magic).
# The conversion spm(W) is part of the timed statement.
for spm in [bsr_matrix, coo_matrix, csc_matrix, csr_matrix]:
    print(spm.__name__)
    %time s, u = forward1(spm(W), s0, f_fn, gam, ups, g_fn, n_states, n_iter)
    print("")
bsr_matrix
CPU times: user 4.21 s, sys: 289 ms, total: 4.5 s
Wall time: 4.57 s

coo_matrix
CPU times: user 3.51 s, sys: 329 ms, total: 3.84 s
Wall time: 3.85 s

csc_matrix
CPU times: user 4.54 s, sys: 309 ms, total: 4.85 s
Wall time: 4.99 s

csr_matrix
CPU times: user 4.01 s, sys: 303 ms, total: 4.32 s
Wall time: 4.32 s
# play with
#%time s, u = forward1(bsr_matrix(W), s0, f_fn, gam, ups, g_fn, n_states, n_iter)

Line Profiling

# Profile forward1 line by line, with W passed as a COO matrix
lp = LineProfiler()
run = lp(forward1)  # wrapped version of forward1 whose calls are recorded by lp
run(coo_matrix(W), s0, f_fn, gam, ups, g_fn, n_states, n_iter)
lp.print_stats()  # print the per-line timing table
Timer unit: 1e-06 s

Total time: 3.96694 s
File: <ipython-input-4-9cab4e0ffaf1>
Function: forward1 at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def forward1(W, s0, f_fn, gam, ups, g_fn, n_states, n_iter, dtype_flt=np.float64):
     2                                               # Allocate computer memory
     3         1      17479.0  17479.0      0.4      s = np.zeros(shape=(n_states, n_iter), order='F', dtype=np.uint8)
     4         1         21.0     21.0      0.0      s[:, 0] = s0
     5         1         35.0     35.0      0.0      u = np.zeros(shape=(n_states, n_iter), order='F', dtype=dtype_flt)
     6         1         67.0     67.0      0.0      mem = (1 - gam).astype(dtype=dtype_flt)
     7                                               
     8      5000       5527.0      1.1      0.1      for t in range(1, n_iter):
     9                                                   # internal value u ("membrane potential")
    10                                                   # u[:, t] = gam * f_fn(W.dot(s[:, t-1])) + mem * u[:, t-1]
    11      4999    1471066.0    294.3     37.1          tmp1 = W.dot(s[:, t-1])
    12      4999     856024.0    171.2     21.6          tmp2 = f_fn(tmp1)
    13      4999      56538.0     11.3      1.4          tmp3 = gam * tmp2
    14      4999     101891.0     20.4      2.6          tmp4 = mem * u[:, t-1]
    15      4999     131849.0     26.4      3.3          tmp5 = tmp3 + tmp4
    16      4999     477639.0     95.5     12.0          u[:, t] = tmp5
    17                                           
    18                                                   # state value s (spikes)
    19      4999     162944.0     32.6      4.1          s[:, t] = g_fn(u[:, t], ups)
    20                                           
    21                                                   # if signal fired s=1 then reset u=0
    22      4999      21466.0      4.3      0.5          mask = s[:, t].astype(bool)
    23      4999     664397.0    132.9     16.7          u[mask, t] = 0
    24                                               
    25         1          1.0      1.0      0.0      return s, u

Test 2 - W:sparse and s:sparse

def forward2(spm, W_, s0, f_fn, gam_, ups, g_fn, n_states, n_iter, dtype_flt=np.float64):
    """Simulate the spiking RNN with sparse weights AND sparse state vectors.

    For every time step ``t >= 1`` the update is::

        u[:, t] = gam * f_fn(W.dot(s[t-1])) + (1 - gam) * u[:, t-1]
        s[t]    = g_fn(u[:, t], ups)
        u[:, t] = 0   wherever s[t] is non-zero

    where ``W.dot(s[t-1])`` is a sparse-matrix times sparse-column-vector
    product.

    Parameters
    ----------
    spm : callable
        Sparse matrix constructor (e.g. ``csr_matrix``) used for both the
        weight matrix and the state vectors.
    W_ : matrix-like
        Weight matrix of shape ``(n_states, n_states)``; converted with ``spm``.
    s0 : array_like
        Initial state (1-D, length ``n_states``).
    f_fn : callable
        Transfer function, applied element-wise to the full product vector.
    gam_ : array_like
        Smoothing factor per state.
    ups : ndarray
        Firing threshold per state, forwarded to ``g_fn``.
    g_fn : callable
        Activation function ``g_fn(u, ups)`` returning the 1-D spike vector.
    n_states, n_iter : int
        Number of states and number of time steps.
    dtype_flt : data-type, optional
        Floating point type of ``u`` and of the smoothing factors.

    Returns
    -------
    s : list
        Sparse column vectors of shape ``(n_states, 1)``; ``s[t]`` holds the
        spikes of time step ``t`` (``s[0]`` is the initial state).
    u : ndarray of dtype_flt, shape (n_states, n_iter), Fortran order
        History of the internal value ("membrane potential") after resets.
    """
    # Allocate computer memory
    u = np.zeros(shape=(n_states, n_iter), order='F', dtype=dtype_flt)
    gam = np.asarray(gam_).reshape(-1).astype(dtype=dtype_flt)
    mem = (1 - gam).astype(dtype=dtype_flt)

    # apply sparse matrix type (spm)
    s = [spm(s0).T]  # append s0 as sparse column vector to list
    W = spm(W_)

    for t in range(1, n_iter):
        # internal value u ("membrane potential")
        # The sparse product is densified BEFORE the transfer function is
        # applied: f_fn must also see the implicit zeros, because f_fn(0)
        # is generally not 0 (expit(0) == 0.5). Applying f_fn to the stored
        # entries only (`.data`) silently treats those rows as f_fn(.) == 0.
        drive = W.dot(s[t-1]).toarray().reshape(-1)
        u[:, t] = gam * f_fn(drive) + mem * u[:, t-1]

        # state value s (spikes)
        spikes = g_fn(u[:, t], ups)
        s.append(spm(spikes).T)

        # if signal fired s=1 then reset u=0
        # Index by the positions of the spikes themselves. s[t] is a column
        # vector, so the COLUMN indices of its non-zeros are all 0 and must
        # not be used here.
        u[np.flatnonzero(spikes), t] = 0

    return s, u

Run Tests

# Time forward2 once per sparse format (IPython %time magic).
# forward2 converts W and the states to the given format itself,
# so the conversions are part of the timed call.
for spm in [bsr_matrix, coo_matrix, csc_matrix, csr_matrix]:
    print(spm.__name__)
    %time s, u = forward2(spm, W, s0, f_fn, gam, ups, g_fn, n_states, n_iter)
    print("")
bsr_matrix
CPU times: user 32.3 s, sys: 892 ms, total: 33.2 s
Wall time: 33.3 s

coo_matrix
CPU times: user 41.7 s, sys: 1.17 s, total: 42.8 s
Wall time: 42.9 s

csc_matrix
CPU times: user 20.2 s, sys: 1.14 s, total: 21.3 s
Wall time: 21.4 s

csr_matrix
CPU times: user 25.3 s, sys: 934 ms, total: 26.2 s
Wall time: 26.3 s
# play with
#%time s, u = forward2(bsr_matrix, W, s0, f_fn, gam, ups, g_fn, n_states, n_iter)

Line Profiling

# Profile forward2 line by line, using the CSC format for W and the states
lp = LineProfiler()
run = lp(forward2)  # wrapped version of forward2 whose calls are recorded by lp
run(csc_matrix, W, s0, f_fn, gam, ups, g_fn, n_states, n_iter)
lp.print_stats()  # print the per-line timing table
Timer unit: 1e-06 s

Total time: 25.9339 s
File: <ipython-input-9-98c50d8d5a5a>
Function: forward2 at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def forward2(spm, W_, s0, f_fn, gam_, ups, g_fn, n_states, n_iter, dtype_flt=np.float64):
     2                                               # Allocate computer memory
     3         1         41.0     41.0      0.0      u = np.zeros(shape=(n_states, n_iter), order='F', dtype=dtype_flt)
     4         1        125.0    125.0      0.0      gam = gam_.reshape(-1, 1).astype(dtype=dtype_flt)
     5         1        274.0    274.0      0.0      mem = (1 - gam).astype(dtype=dtype_flt)
     6                                           
     7                                               # apply sparse matrix type (spm)
     8         1       5036.0   5036.0      0.0      s = [spm(s0).T]  # append s0 as sparse column vector to list
     9         1       3706.0   3706.0      0.0      W = spm(W_)
    10                                           
    11      5000       6690.0      1.3      0.0      for t in range(1, n_iter):
    12                                                   # internal value u ("membrane potential")
    13                                                   #u[:, t] = gam * f_fn(W.dot(s[t-1].T).toarray().reshape(-1)) + mem * u[:, t-1]
    14      4999   11072711.0   2215.0     42.7          tmp1 = W.dot(s[t-1])  # spm
    15      4999     850647.0    170.2      3.3          tmp1.data = f_fn(tmp1.data)  
    16      4999    3239310.0    648.0     12.5          tmp3 = tmp1.multiply(gam)  # COO?
    17      4999     149486.0     29.9      0.6          tmp4 = mem * u[:, t-1].reshape(-1, 1)  # ndarray
    18      4999    1107497.0    221.5      4.3          tmp5 = tmp3 + tmp4  # matrix
    19      4999     611586.0    122.3      2.4          u[:, t] = tmp5.reshape(-1)
    20                                           
    21                                                   # state value s (spikes)
    22      4999    6746837.0   1349.6     26.0          s.append(spm(g_fn(u[:, t], ups)).T)
    23                                           
    24                                                   # if signal fired s=1 then reset u=0
    25      4999    1757353.0    351.5      6.8          idx = s[t].nonzero()[1]
    26      4999     382576.0     76.5      1.5          u[idx, t] = 0
    27                                               
    28         1          1.0      1.0      0.0      return s, u