Commit 9521579c authored by pablosmig's avatar pablosmig
Browse files

HYGRIP first commit

parents
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
from sklearn.decomposition import FastICA
import warnings
from .pipelines import *
def compute_ica(x, n_components=None, tol=1e-3, max_iter=500):
    """ Computes sources and mixing matrices via FastICA.

    args:
        x : array, shape (num_samps, num_chans)
        n_components : int or None, number of components (defaults to num_chans)
        tol : float, convergence tolerance passed to FastICA
        max_iter : int, maximum number of FastICA iterations
    returns:
        S_ : estimated sources, array (num_samps, num_comps)
        M_ : mixing matrix, array (num_chans, num_comps)
        U_ : unmixing matrix, array (num_comps, num_chans)
        mean_ : per-channel mean removed before whitening, array (num_chans,)
        converged : bool, False when FastICA stopped at max_iter without converging

    ** thanks to ** scikit-learn.org
    """
    assert len(x.shape) == 2, f"x is {len(x.shape)}D but only 2D is accepted"
    assert x.shape[0] > x.shape[1], f"x shape is {x.shape} but time axis should be last"
    # num_components = num_channels by default
    if n_components is None:
        n_components = x.shape[1]
    # NOTE(review): sklearn >= 1.1 deprecates whiten=True in favour of
    # whiten="unit-variance"; kept as-is to preserve numerical behaviour
    ica = FastICA(n_components=n_components, whiten=True, max_iter=max_iter, tol=tol)
    with warnings.catch_warnings(record=True) as w:
        # ensure the ConvergenceWarning is recorded even if globally filtered
        warnings.simplefilter("always")
        S_ = ica.fit_transform(x)  # (num_samps, num_comps)
        # BUG FIX: scan all recorded warnings; the original `len(w) == 1` test
        # missed the convergence warning whenever any other warning fired too
        converged = not any("did not converge" in str(warn.message) for warn in w)
    # mixing_matrix = pinv(components): sources (-> white data) -> data
    M_ = ica.mixing_  # (num_chans, num_comps)
    # pinv(np.dot(unmixing_matrix, self.whitening_)) : data (-> white_data) -> sources
    U_ = ica.components_  # (num_comps, num_chans)
    #  mean over features to undo whitening
    mean_ = ica.mean_  # (num_chans,)
    return S_, M_, U_, mean_, converged
def align(x, y, fs, threshold=1):
    """ Aligns two signals with the same sampling frequency in time.

    Cross-correlates the (rectified) correlation of both signals and, when the
    best lag is within `threshold` seconds, crops both so they overlap at that lag.

    args:
        x : array, shape (tsamps,)
        y : array, shape (tsamps',)
        fs : float, sampling frequency [Hz] of both signals
        threshold : float, maximum delay in seconds that is still corrected
    returns:
        x, y : both cropped to a common (aligned) length
    """
    assert len(x.shape) == 1 and len(y.shape) == 1, \
        f"align only accepts 1D vectors but got {len(x.shape)} and {len(y.shape)}"
    # crop to a common length before correlating
    min_len = np.min([x.shape[0], y.shape[0]])
    x, y = x[:min_len], y[:min_len]
    cor = signal.correlate(x, y, mode="full")
    # rectify (signals can be of opposite sign only)
    cor = np.abs(cor)
    mx_ind = np.argmax(cor)
    # lag of the correlation peak, in samples, relative to zero lag
    delay = mx_ind - cor.shape[0] // 2
    # BUG FIX: delay is in samples, so seconds = delay / fs (the original
    # `delay * fs` made any non-zero lag exceed the threshold for fs > 1 Hz,
    # turning the alignment into a no-op)
    if np.abs(delay / fs) <= threshold:
        if delay > 0:
            x, y = x[delay:], y[:-delay]
        elif delay < 0:
            x, y = x[:delay], y[-delay:]
    return x, y
def reject(S_, ref, max_num, threshold, fs, _debug=False):
    """ Find estimated sources to reject based on `ref` signals. Signals should be high pass filtered.
    Rejects the max_num of sources that are most correlated to any of the ref signals.

    args:
        S_ : array, shape (num_samps, num_comps), estimated sources
        ref : array, shape (num_refs, num_samps'), reference signals
        max_num : int, maximum number of sources to reject
        threshold : float, minimum |pearson r| for a source to be rejected
            (also forwarded to `align` -- NOTE(review): align interprets it as a
            maximum delay in seconds, confirm this double use is intended)
        fs : float, sampling frequency of `S_` and `ref`
        _debug : bool, plot the first rejected source against the references
    returns:
        rejected : list of int, indices of rejected sources (strongest first)
        pears : array (num_comps,), |pearson r| of each source vs its best reference
    """
    assert len(S_.shape) == 2, f"`S_` should be 2D array but found {len(S_.shape)} dims"
    assert len(ref.shape) == 2, f"`ref` should be 2D array but found {len(ref.shape)} dims"
    pears = []
    for i in range(S_.shape[1]):
        paux = []
        for j in range(ref.shape[0]):
            # avoid operations overwriting memory
            x_, y_ = np.copy(S_[:, i]), np.copy(ref[j])
            # standardise signals (should have been highpass filtered already)
            x_, y_ = x_ / x_.std(), y_ / y_.std()
            # align in time before correlating
            x_, y_ = align(x_, y_, fs, threshold)
            # pearson correlation between source i and reference j
            paux.append(np.corrcoef(x_, y_)[0, 1])
        # BUG FIX: select the correlation with the largest magnitude; a plain
        # max() ignored strong negative correlations even though the final
        # score below is |r| (sources may be anti-correlated with a reference)
        p = max(paux, key=abs)
        pears.append(p)
    pears = np.abs(np.array(pears))
    # keep up to max_num sources whose |r| exceeds the threshold
    rejected = []
    max_inds = list(pears.argsort()[-max_num:][::-1])
    for i in max_inds:
        if pears[i] > threshold:
            rejected.append(i)
    if _debug and len(rejected) > 0:
        plt.plot(S_[:, rejected[0]], alpha=0.5)
        plt.plot(ref.T, alpha=0.2)
        plt.gcf().set_size_inches(16, 4)
        plt.show()
    return rejected, pears
def reconstruct(x, M_, U_, mean_, reject):
    """ Reconstruct original signal without rejected estimated sources.
    Data should not be standardised since U_/M_ already whiten/color the signals.

    args:
        x : original signal, array (num_samps, num_chans)
        M_ : mixing matrix, array (num_chans, num_comps)
            pinv(np.dot(unmixing_matrix, self.whitening_))
        U_ : unmixing matrix, array (num_comps, num_chans)
        mean_ : array, (num_chans,), per-channel mean removed by the ICA fit
        reject : list of int, components to zero in the recovered sources
    returns:
        rec : reconstructed signal, array (num_samps, num_chans)
    """
    assert isinstance(reject, list), \
        f"rejected_components is type {type(reject)} but can only be list"
    # Estimate sources.
    # BUG FIX: centre the data with mean_ before unmixing (mirrors
    # FastICA.transform); without it the mean was effectively added twice in
    # the reconstruction, so rec != x even with no rejected components
    S_ = np.dot(x - mean_, U_.T)  # (num_samps, num_comps)
    # zero rejected recovered sources
    for i in reject:
        S_[:, i] = 0
    rec = np.dot(S_, M_.T) + mean_  # (num_samps, num_chans)
    return rec
class IcaReconstructor(object):
    """ Removes reference-correlated ICA components from a recording.

    ICA gives better results when there are more samples: the gathered modes
    concatenate samples per channel before fitting. On the other hand, if there
    are many different sources along a recording this might be less well
    captured; if we guess there are different sources for different conditions
    or at different recording times, selecting the time-periods (`tlim`) to
    apply the ICA to might improve results.
    """

    def __init__(self, fs, n_components=None, max_components=2, threshold=0.3, tolerance=1e-2, max_iter=500,
                 prefilter=None, tlim=None, rec_tlim=None, mode="continuous", standarize=False, verbose=True):
        """ args:
            fs : float, sampling frequency [Hz] the ICA is fit at (inputs are decimated to it)
            n_components : int or None, number of ICA components (None -> one per channel)
            max_components : int, maximum number of components to reject per ICA fit
            threshold : float, correlation threshold for rejecting a component
            tolerance : float, FastICA convergence tolerance
            max_iter : int, FastICA iteration cap
            prefilter : optional process called as prefilter(x, None, fs), applied after
                decimation (e.g. a highpass to remove offset/drift)
            tlim : (t0, t1) [s], window the ICA is fit on; stored only in "per_epoch" mode
            rec_tlim : (t0, t1) [s], window that is reconstructed; stored only in "per_epoch" mode
            mode : str, one of "continuous", "per_epoch", "per_condition"
            standarize : bool, divide signals by their std over time before ICA
            verbose : bool, print a rejection summary per ICA fit
        """
        self._fs = fs
        self._n_components = n_components
        self._max_components = max_components
        self._threshold = threshold
        self._prefilter = prefilter
        self._cropper = StartEndCropper()
        self._tol = tolerance
        self._max_iter = max_iter
        # the time windows are forced to None unless each epoch gets its own fit
        if mode == "per_epoch":
            self._tlim = tlim
            self._rec_tlim = rec_tlim
        else:
            self._tlim = None
            self._rec_tlim = None
        self._stdze = standarize
        self._mode = mode
        self._verbose = verbose
        if n_components is not None:
            # NOTE(review): the message reads inverted -- the check actually
            # enforces max_components <= n_components
            assert max_components <= n_components, \
                "max_num should be greater than n_components"

    def __repr__(self):
        # show "as channels" when the component count defaults to num channels
        if self._n_components is None:
            n = '\"as channels\"'
        else:
            n = str(self._n_components)
        return "IcaPre" + self._prefilter.__repr__() + \
            f"\tIcaReconstructor(sfreq[Hz]={self._fs}, n_comps={n}, max={self._max_components}," + \
            f" thres={self._threshold}, tol={self._tol}, max_iter={self._max_iter}, tlim={self._tlim}," + \
            f" mode={self._mode}, standarize={self._stdze})\n"

    def __call__(self, x, evts, fs, **kwargs):
        """ Fits ICA on a decimated/filtered copy of `x`, rejects components that
        correlate with the reference signals, and reconstructs `x` without them.

        args:
            x : array; (chans, samps) continuous, 3D epoched, or 4D gathered-by-label
                -- NOTE(review): shapes inferred from the branches below, confirm
            evts : events array matching the epoching of `x`
            fs : float, sampling frequency [Hz] of `x`
        kwargs (required):
            ref : reference signal(s), epoched the same way as `x`
            ref_fs : sampling frequency [Hz] of `ref`
            ref_evts : events of `ref`
        returns:
            (x, evts, fs) with rejected components removed from `x`
        """
        assert "ref" in kwargs, \
            'IcaReconstructor requires a reference signal with key \"ref\"'
        assert "ref_fs" in kwargs, \
            'IcaReconstructor requires a reference sampling frequency with key \"ref_fs\"'
        assert "ref_evts" in kwargs, \
            'IcaReconstructor requires reference events with key \"ref_evts\"'
        ref, ref_fs, ref_evts = kwargs["ref"], kwargs["ref_fs"], kwargs["ref_evts"]
        assert len(x.shape) == len(ref.shape), \
            f"reference signal ({len(ref.shape)}D) and input ({len(x.shape)}D) should have undergone same epoching"
        if self._mode == "per_epoch":
            assert len(x.shape) > 2, \
                f"invalid input shape ({len(x.shape)}D) for \'per_epoch\' mode"
        if self._mode == "per_condition":
            assert len(x.shape) == 4, \
                f"invalid input shape ({len(x.shape)}D) for \'per_condition\' mode"
        #  copy to not corrupt originals
        x_, ref_ = np.copy(x), np.copy(ref)
        evts_, ref_evts_ = np.copy(evts), np.copy(ref_evts)
        fs_, ref_fs_ = np.copy(fs), np.copy(ref_fs)
        # crop start to end of recordings (continuous 2D recordings only)
        if len(x.shape) == 2:
            x_, evts_, fs_ = self._cropper(x_, evts_, fs_)
            ref_, ref_evts_, ref_fs_ = self._cropper(ref_, ref_evts_, ref_fs_)
        # check input is numpy array so we can write on it
        # NOTE(review): np.array is a factory function, not a type, so this branch
        # always runs; x[()] presumably materialises a lazily-loaded (e.g. hdf5)
        # dataset -- confirm with callers
        if type(x) != np.array:
            x = x[()]
        # prestandarise in time axis
        if self._stdze:
            x_ /= x_.std(-1, keepdims=True)
            ref_ /= ref_.std(-1, keepdims=True)
            x /= x.std(-1, keepdims=True)
        #  decimate to ICA sfreq
        x_ = decimate(x_, fs, self._fs, axis=-1)
        ref_ = decimate(ref_, ref_fs, self._fs, axis=-1)
        #  apply filter to center and remove offset
        if self._prefilter is not None:
            x_, _, fs_ = self._prefilter(x_, None, self._fs)
            ref_, _, ref_fs_ = self._prefilter(ref_, None, self._fs)
        #  Indices for ICA computation (times relative to the first event onset)
        # NOTE(review): this indexing assumes 3D evts; continuous 2D inputs would
        # need the 2D fallback used further below -- confirm with callers
        t_on = evts[0, 0, 0]
        times_ = np.arange(x_.shape[-1]) / self._fs - t_on
        if self._tlim is None:
            inds = [0, -1]
        else:
            inds = [np.argmin(np.abs(times_ - t)) for t in self._tlim]
        # Indices for ICA reconstruction
        if self._rec_tlim is None:
            inds_rec = [0, -1]
        else:
            inds_rec = [np.argmin(np.abs(times_ - t)) for t in self._rec_tlim]
        reshape_x = False
        if self._mode == "per_condition":
            # one ICA fit per condition, samples of all its trials concatenated
            # crop to desired time section
            x_, ref_ = x_[..., inds[0]:inds[1]], ref_[..., inds[0]:inds[1]]
            # serialize (conds, chans, times * trials)
            x_ = np.moveaxis(x_, 1, -1)
            ref_ = np.moveaxis(ref_, 1, -1)
            x_ = x_.reshape((x_.shape[0], x_.shape[1], -1), order='F')
            ref_ = ref_.reshape((ref_.shape[0], ref_.shape[1], -1), order='F')
            for i in range(x_.shape[0]):
                #  compute ICA
                S_, M_, U_, mean_, conv = compute_ica(x_[i].T, n_components=self._n_components, tol=self._tol,
                                                      max_iter=self._max_iter)
                rejected, pears = reject(S_, ref_[i], self._max_components, self._threshold, self._fs)
                # Reconstruct original epochs on desired time
                for j in range(x.shape[1]):
                    x[i, j, :, inds_rec[0]:inds_rec[1]] = reconstruct(x[i, j, :, inds_rec[0]:inds_rec[1]].T, M_, U_,
                                                                      mean_, rejected).T
                if self._verbose:
                    print(f"\t cond {i}: conv {conv}: {len(rejected)} rejected component(s) above corrcoef_thres = {self._threshold} : {pears[rejected]}")
        elif self._mode in ["continuous", "per_epoch"]:
            # gathered 4D inputs: merge the two label groups back into one trial axis
            if len(x_.shape) == 4:
                x_ = np.concatenate([x_[0], x_[1]], 0)
                ref_ = np.concatenate([ref_[0], ref_[1]], 0)
                if len(x.shape) == 4:
                    x = np.concatenate([x[0], x[1]], 0)
                    reshape_x = True  # remember to split back into the two groups
            if self._mode == "continuous":
                # single ICA fit over the whole (serialized) recording
                if len(x_.shape) == 3:
                    # crop to desired time section
                    t_on = evts[0, 0, 0] if len(evts_.shape) == 3 else evts[0, 0]
                    times_ = np.arange(x_.shape[-1]) / self._fs - t_on
                    if self._tlim is None:
                        inds = [0, -1]
                    else:
                        inds = [np.argmin(np.abs(times_ - t)) for t in self._tlim]
                    x_, ref_ = x_[..., inds[0]:inds[1]], ref_[..., inds[0]:inds[1]]
                    # serialize (chans, times * trials * conds)
                    x_ = np.moveaxis(x_, 0, -1)
                    ref_ = np.moveaxis(ref_, 0, -1)
                    x_ = x_.reshape((x_.shape[0], -1), order='F')
                    ref_ = ref_.reshape((ref_.shape[0], -1), order='F')
                #  compute ICA
                S_, M_, U_, mean_, conv = compute_ica(
                    x_.T, n_components=self._n_components, tol=self._tol, max_iter=self._max_iter
                )
                rejected, pears = reject(S_, ref_, self._max_components, self._threshold, self._fs)
                # Reconstruct original...
                if len(x.shape) == 2:  #  ... continuous recording
                    x = reconstruct(x.T, M_, U_, mean_, rejected).T
                else:  # ... epochs
                    for i in range(x.shape[0]):
                        x[i, ..., inds_rec[0]:inds_rec[-1]] = reconstruct(x[i, ..., inds_rec[0]:inds_rec[-1]].T, M_, U_,
                                                                          mean_, rejected).T
                if self._verbose:
                    print(f"\tconv {conv} : {len(rejected)} rejected component(s) above corrcoef_thres = {self._threshold} : {pears[rejected]}")
            elif self._mode == "per_epoch":
                # one ICA fit (and rejection) per epoch
                for i in range(x.shape[0]):
                    S_, M_, U_, mean_, conv = compute_ica(x_[i].T, n_components=self._n_components, tol=self._tol,
                                                          max_iter=self._max_iter)
                    rejected, pears = reject(S_, ref_[i], self._max_components, self._threshold, self._fs)
                    if self._verbose:
                        print(f"\tepoch {i} : conv {conv} : {len(rejected)} rejected component(s) above corrcoef_thres = {self._threshold} : {pears[rejected]}")
                    x[i, ..., inds_rec[0]:inds_rec[-1]] = reconstruct(
                        x[i, ..., inds_rec[0]:inds_rec[-1]].T, M_, U_, mean_, rejected
                    ).T
        else:
            raise NotImplementedError(f"{self._mode} not implemented")
        if reshape_x:
            # restore the (2, trials, chans, samps) gathered-by-label layout
            x0, x1 = np.split(x, 2, axis=0)
            x0, x1 = np.expand_dims(x0, 0), np.expand_dims(x1, 0)
            x = np.concatenate((x0, x1), 0)
        return x, evts, fs
\ No newline at end of file
import numpy as np
from scipy import signal
import multiprocessing
from joblib import Parallel, delayed
def primes(n, threshold):
    """ Prime factorisation of `n` in ascending order.

    args:
        n : int, number to factorise
        threshold : int, exclusive upper bound on any prime factor
    returns:
        list of int prime factors (empty for n <= 1)
    raises:
        AssertionError : if any prime factor is >= threshold
    """
    primfac = []
    d = 2
    while d * d <= n:
        while (n % d) == 0:
            assert d < threshold, \
                f"factors are greater than threshold {threshold}"
            primfac.append(d)
            n //= d
        d += 1
    # whatever remains (> 1) is itself prime
    assert n < threshold, \
        f"factors are greater than threshold {threshold}"
    if (n > 1):
        primfac.append(n)
    return primfac


def decimate(x, in_fs, out_fs, axis=-1, threshold=13):
    """ Decimates signal to out_fs avoiding downsampling factors greater than `threshold`.

    Downsampling is staged: factors of 10 first, then the remaining prime
    factors one at a time (scipy recommends multiple stages for factors > 13).

    args:
        x : array, signal to decimate along `axis`
        in_fs : int, input sampling frequency [Hz]; must be a multiple of out_fs
        out_fs : int, output sampling frequency [Hz]
        axis : int, axis to decimate over
        threshold : int, exclusive upper bound for any prime decimation factor
    returns:
        decimated array (x returned unchanged when in_fs == out_fs)
    """
    if in_fs == out_fs:
        return x
    assert in_fs % out_fs == 0, \
        f"input sfreq ({in_fs}Hz) should be divisible by output sfreq ({out_fs}Hz)"
    # BUG FIX: use integer division (divisibility just asserted); the float
    # quotient previously leaked float values into primes()' arithmetic
    dsf = in_fs // out_fs
    # peel off factors of 10 first
    while dsf % 10 == 0:
        x = signal.decimate(x, 10, axis=axis)
        dsf //= 10
    # then decimate by each remaining prime factor
    for factor in primes(dsf, threshold):
        x = signal.decimate(x, int(factor), axis=axis)
    return x
class Pipeline(object):
    """ A named sequence of processing steps applied to (x, evts, fs).

    Each step is a callable step(x, evts, fs, **kwargs) -> (x, evts, fs).
    An empty step list marks a placeholder pipeline that passes data through.
    """

    def __init__(self, processes, name=""):
        self._name = name
        self._processes = processes if len(processes) else None

    def __repr__(self):
        header = f"Pipeline({self._name})\n"
        if self._processes is None:
            return header + f"\tplaceholder"
        return header + "".join(f"\t{step.__repr__()}" for step in self._processes)

    def __call__(self, x, evts, fs, **kwargs):
        if self._processes is None:
            return x, evts, fs
        for step in self._processes:
            x, evts, fs = step(x, evts, fs, **kwargs)
        return x, evts, fs
# Crop
class EpochExtractor(object):
    """ Cuts fixed-length epochs around every labelled event.

    Events with label -1 are skipped. The epoch window is
    [t_evt + tbound[0], t_evt + tbound[1]] in seconds.
    """

    def __init__(self, tbound):
        # (t_start, t_end) window in seconds relative to each event onset
        self._tbound = tbound

    def __repr__(self):
        return f"EpochExtractor(t_0[s]={self._tbound[0]}, t_end[s]={self._tbound[1]})\n"

    def __call__(self, x, evts, fs, **kwargs):
        times = np.arange(x.shape[-1]) / fs
        # window bounds in samples
        offsets = [int(t * fs) for t in self._tbound]
        epochs, events = [], []
        for row in evts:
            # label -1 marks events to ignore
            if row[1] == -1:
                continue
            # sample index closest to the event onset
            onset = int(np.argmin(np.abs(times - row[0])))
            epochs.append(
                np.expand_dims(x[:, onset + offsets[0]:onset + offsets[1]], axis=0)
            )
            # event time shifted to be relative to the epoch start
            events.append(
                np.expand_dims(np.array((-self._tbound[0], row[1])), axis=0)
            )
        return np.concatenate(epochs, axis=0), np.concatenate(events, axis=0), fs
# Gather by labels
class LabelGatherer(object):
    """ Stacks trials into a leading per-label axis.

    Output shapes gain a leading axis of len(labels):
    x -> (num_labels, trials_per_label, ...), evts likewise.
    """

    def __init__(self, labels=(0, 1)):
        # BUG FIX: the labels argument was stored but never used (__call__
        # hard-coded 0/1); also replaced the mutable default list with a tuple
        self._labels = labels

    def __repr__(self):
        return f"LabelGatherer()\n"

    def __call__(self, x, evts, fs, **kwargs):
        gathered_x, gathered_evts = [], []
        for label in self._labels:
            mask = evts[:, 1] == label
            gathered_x.append(np.expand_dims(x[mask], axis=0))
            gathered_evts.append(np.expand_dims(evts[mask], axis=0))
        x = np.concatenate(gathered_x, axis=0)
        evts = np.concatenate(gathered_evts, axis=0)
        return x, evts, fs
# Filter
class Filter(object):
    """ Zero-phase IIR filter process (butterworth, elliptic or notch).

    `bands` is interpreted according to `filttype`:
      - "butter"/"ellip": bands = (passband, stopband) edge frequencies in Hz;
        each a [low, high] list for bandpass design, or a scalar edge for
        lowpass/highpass (lowpass when the passband edge < stopband edge)
      - "notch": bands = (f0, Q) with centre frequency f0 [Hz] and quality Q

    raises:
        ArithmeticError : if the designed filter is unstable
    """

    def __init__(self, bands, fs, gpass=3, gstop=60, filttype="butter"):
        self._fs = fs
        self._filttype = filttype
        self._gpass, self._gstop = gpass, gstop
        if self._filttype != "notch":
            if type(bands[0]) == list:
                # bandpass design: passband must lie strictly inside the stopband
                assert bands[0][0] > bands[1][0] and bands[0][1] < bands[1][1], \
                    f"Bandpass {bands[0]} Hz should be contained in bandstop {bands[1]} Hz."
                # normalise edges to the Nyquist frequency (scipy expects 0..1)
                self._wp = [f / (fs / 2) for f in bands[0]]
                self._ws = [f / (fs / 2) for f in bands[1]]
                self._btype = "bandpass"
            else:
                self._wp = bands[0] / (fs / 2)
                self._ws = bands[1] / (fs / 2)
                self._btype = "lowpass" if self._wp < self._ws else "highpass"
        else:
            self._f0 = bands[0]  # central freq
            self._Q = bands[1]  #  Q = ( f0 / (fs / 2) ) / bw -> -3dB bandwidth
            self._btype = "bandstop"
        # design at the minimum order meeting the gpass/gstop specification
        if self._filttype == "butter":
            self._n, self._wn = signal.buttord(self._wp, self._ws, gpass, gstop)
            self._ba = signal.butter(self._n, self._wn, btype=self._btype)
        elif self._filttype == "ellip":
            self._n, self._wn = signal.ellipord(self._wp, self._ws, gpass, gstop)
            self._ba = signal.ellip(self._n, 1, 10, self._wn, btype=self._btype)
        elif self._filttype == "notch":
            self._ba = signal.iirnotch(self._f0, self._Q, self._fs)
        else:
            raise NotImplementedError(f"{self._filttype} not implemented")
        # stability check: all poles must lie inside the unit circle
        if not np.all(np.abs(np.roots(self._ba[1])) < 1):
            raise ArithmeticError(f"unstable filter, denominator roots bigger than 1")

    def __repr__(self):
        if self._filttype != "notch":
            # denormalise the design edge(s) back to Hz for display
            if type(self._wn) == np.ndarray or type(self._wn) == list:
                fn = [w * (self._fs / 2) for w in self._wn]
                fn = f"[{fn[0]: 2.2f}, {fn[1]: 2.2f}]"
            else:
                fn = self._wn * (self._fs / 2)
                fn = f"{fn: 2.2f}"
            s = f"Filter(type={self._filttype}-{self._btype}, order={self._n}, sfreq[Hz]={self._fs}, fn[Hz]={fn}, gpass[dB]={self._gpass}, gstop[dB]={self._gstop})\n"
        else:
            s = f"Filter(type={self._filttype}-{self._btype}, sfreq[Hz]={self._fs}, fn0[Hz]={self._f0: 2.2f}, Q[1]={self._Q})\n"
        return s

    def __call__(self, x, evts, fs, **kwargs):
        """ Applies the filter forward and backward (zero phase) along the last axis. """
        assert fs == self._fs, \
            f"Signal sfreq {fs} Hz different from process sfreq {self._fs} Hz"
        x = signal.filtfilt(self._ba[0], self._ba[1], x)
        # filtfilt can emit NaN/Inf when the filter is borderline unstable
        if (np.isnan(x).any()):
            raise ValueError("filter returned NaN values")
        if (np.isinf(x).any()):
            raise ValueError("filter returned Inf values")
        return x, evts, fs
# Downsample to specified group fs
class Downsampler(object):
    """ Decimates signals along `axis` from `in_fs` down to `out_fs`. """

    def __init__(self, in_fs, out_fs, axis=-1, threshold=13):
        assert in_fs % out_fs == 0, \
            f"input sfreq ({in_fs}Hz) should be divisible by output sfreq ({out_fs}Hz)"
        self._in_fs = in_fs
        self._out_fs = out_fs
        self._axis = axis
        self._threshold = threshold

    def __repr__(self):
        return ("Downsampler(" +
                f"sfreq_in[Hz]={self._in_fs}, sfreq_out[Hz]={self._out_fs}, " +
                f"axis={self._axis}, thres={self._threshold})\n")

    def __call__(self, x, evts, fs, **kwargs):
        assert self._in_fs == fs, \
            f"x sfreq ({fs}Hz) different from expected sfreq ({self._in_fs}Hz)"
        downsampled = decimate(x, fs, self._out_fs, axis=self._axis, threshold=self._threshold)
        return downsampled, evts, self._out_fs
class VCconverter(object):
    """ Converts force in V to Voluntary Contraction per unit.
    Force in V needs to be filtered to remove offset.
    Requires kwarg `mvc` with Maximum Voluntary Contraction in V.
    """

    def __repr__(self):
        return f"VCconverter()\n"

    def __call__(self, x, evts, fs, **kwargs):
        assert len(x.shape) == 4, "Input needs to be gathered by label"
        mvc = kwargs["mvc"]
        # normalise each label group (in place) by its own MVC
        for cond in range(x.shape[0]):
            x[cond] = x[cond] / mvc[cond]
        return x, evts, fs
# Compute HbO, HbR
class HbConverter(object):
""" Converts light intensity in Oxy and Deoxy Hemoglobin concentration.
Parameters may be introduced as a list in the same order and units as specified in `ds.attrs["nirs_..."]`
or as text in the same format.
Compared to other processes, this one takes 2 signal inputs and returns...
"""
def __init__(self, DPF, SD, exc, tbound=None):
    """ args:
        DPF : differential pathlength factors (2 values); see __parse__ for accepted formats
        SD : source-detector separation (1 value)
        exc : extinction coefficients (4 values, reshaped to (2, 2))
        tbound : optional (t0, t1) seconds -- presumably the baseline window for
            the optical-density conversion in __call__; confirm (method truncated here)
    """
    self._DPF = self.__parse__(DPF)
    self._SD = self.__parse__(SD)
    self._exc = self.__parse__(exc)
    self._tbound = tbound
@staticmethod
def __parse__(param):
    """ Normalises a parameter given as bytes, str or list into numeric form.

    Accepts e.g. b"a=6.0,b=5.2", "a=6.0,b=5.2" or [6.0, 5.2] and returns,
    depending on the number of values:
      2 -> 1D array (DPF, one per wavelength)
      4 -> (2, 2) array (extinction coefficients)
      1 -> scalar float (source-detector separation)
    """
    # bytes (presumably dataset attributes) -> str
    if type(param) == np.bytes_:
        param = param.decode("utf-8")
    if type(param) == str:
        # comma-separated "name=value" pairs; keep the values only
        param = param.split(",")
        param = [p.split("=")[-1] for p in param]
        param = [float(p) for p in param]
    elif type(param) != list:
        raise NotImplementedError(f"param has to be of type np.bytes_, str or list but is {type(param)}")
    if len(param) == 2:  # DPF
        param = np.array(param).transpose()
    elif len(param) == 4:  # exc
        param = np.reshape(np.array(param), (2, 2))
    elif len(param) == 1:  # SD
        param = param[0]
    return param
def __repr__(self):
    # units follow the dataset attribute conventions (SD in cm, exc in cm^-1 mol^-1)
    return f"HbConverter(tbound[s]={self._tbound}, DPF[1]={self._DPF}, SD[cm]={self._SD}, exc[cm^-1 mol^-1]={self._exc.reshape(-1)})\n"
def __call__(self, x, evts, fs, **kwargs):
assert (x > 0).all(), "Optical intensity must be strictly positive."
# Converts to Optical Density (OD)
if self._tbound is not None:
ind = [int(t * fs) for t in self._tbound]
else:
ind = [0, -1]
x = -np.log10(x / np.mean(x[..., ind[0]:ind[1]], axis=-1, keepdims=True))
wl1, wl2 = np.split(x, 2, axis=-2)
x = []
# Apply Beer-Lambert law
C = np.linalg.pinv(self._SD * self._DPF * self._exc)
for i in range(wl1.shape[-2]):
x.append(
np.matmul(