[Mne_analysis] Spatiotemporal Cluster Analysis of Time-Frequency Data
Maryam Zolfaghar
Maryam.Zolfaghar at colorado.edu
Thu May 9 14:40:18 EDT 2019
Hi all,
I have been using spatio_temporal_cluster_1samp_test on my time-frequency
data (power stored as EpochsTFR, with shape [epochs x chan x freqs x time]).
I have read the paper cited for the permutation tests (Maris & Oostenveld,
2007), but it is still not clear to me exactly what these functions do, and
my plots have left me confused about the cluster analysis. My code is below;
it uses a public dataset (eegbci) as an example. It may look long, but most
of it is plotting. There are two main functions, and the key calls were
highlighted in bold in the original HTML message.
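For reference, the core idea of the one-sample cluster-based permutation test (Maris & Oostenveld, 2007) can be sketched in plain NumPy: threshold the t-map, group neighbouring supra-threshold samples into clusters, sum the t-values inside each cluster (its "mass"), and compare each observed mass with the distribution of the largest cluster mass obtained after randomly flipping the sign of each subject's data. This is a simplified sketch under stated assumptions, not MNE's implementation: all helper names are hypothetical, neighbours are the four adjacent samples on a 2-D frequency x time grid, and the sigma term is only an assumed approximation of the "hat" adjustment used with ttest_1samp_no_p.

```python
import numpy as np


def tstat_1samp(X, sigma=0.0):
    """One-sample t statistic along axis 0 (observations first).

    sigma > 0 adds sigma * max(variance) to every variance; this is an
    assumed approximation of the "hat" adjustment, not copied from MNE.
    """
    var = X.var(axis=0, ddof=1)
    var = var + sigma * var.max()
    return X.mean(axis=0) / np.sqrt(var / X.shape[0])


def label_clusters(mask):
    """Label 4-connected True regions of a 2-D boolean mask (flood fill)."""
    labels = np.zeros(mask.shape, dtype=int)
    n_labels = 0
    for start in zip(*np.nonzero(mask)):
        if labels[start]:
            continue  # already part of an earlier cluster
        n_labels += 1
        labels[start] = n_labels
        stack = [start]
        while stack:
            i, j = stack.pop()
            for di, dj in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                ni, nj = i + di, j + dj
                if (0 <= ni < mask.shape[0] and 0 <= nj < mask.shape[1]
                        and mask[ni, nj] and not labels[ni, nj]):
                    labels[ni, nj] = n_labels
                    stack.append((ni, nj))
    return labels, n_labels


def cluster_masses(t, threshold):
    """Return (mask, |sum of t|) for every positive and negative cluster."""
    out = []
    for sign in (1, -1):  # two-tailed: positive and negative clusters apart
        labels, n = label_clusters(sign * t > threshold)
        for k in range(1, n + 1):
            m = labels == k
            out.append((m, abs(t[m].sum())))
    return out


def cluster_perm_test_1samp(X, threshold, n_perm=1000, sigma=0.0, seed=0):
    """Sign-flip cluster permutation test; X is (n_obs, n_freqs, n_times)."""
    rng = np.random.default_rng(seed)
    clusters = cluster_masses(tstat_1samp(X, sigma), threshold)
    max_null = np.zeros(n_perm)
    for k_perm in range(n_perm):
        # Under H0 each subject's data is symmetric around 0: flip signs
        signs = rng.choice([-1.0, 1.0], size=X.shape[0])[:, None, None]
        null = cluster_masses(tstat_1samp(X * signs, sigma), threshold)
        max_null[k_perm] = max((mass for _, mass in null), default=0.0)
    # One p-value per CLUSTER, shared by every sample inside it
    p_values = [(np.sum(max_null >= mass) + 1) / (n_perm + 1)
                for _, mass in clusters]
    return clusters, p_values


# Usage on synthetic data: 19 "subjects" x 20 freqs x 30 times
rng = np.random.default_rng(42)
X = rng.normal(size=(19, 20, 30))
X[:, 5:10, 10:20] += 1.0  # planted effect in one time-frequency patch
# threshold 2.1 is roughly the two-tailed p = 0.05 t-value for 18 dof
clusters, p_values = cluster_perm_test_1samp(X, threshold=2.1, n_perm=200,
                                             sigma=1e-3)
```

The test yields one p-value per cluster, and that value describes the cluster as a whole; the planted patch should come out as one large cluster with a small p-value.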
More specifically, is converting the participants' power data to EpochsTFR
and AverageTFR the correct way to analyze multiple participants? That is,
after computing the power/ITC data for each participant, I store them in a
multidimensional array and then convert that array to EpochsTFR or
AverageTFR.
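In array terms, the script stacks one epoch-averaged TFR per participant along a new first axis; wrapping that array in EpochsTFR makes the "epochs" axis index participants, and the mean over that axis is the grand average stored in AverageTFR. The NumPy sketch below uses random stand-in data with illustrative, hypothetical sizes (3 channels, 47 frequencies, 97 time points) to show the shapes involved, including a group-test input that keeps participants as the first (observation) axis:

```python
import numpy as np

rng = np.random.default_rng(0)
n_subj, n_chan, n_freqs, n_times = 19, 3, 47, 97  # illustrative sizes
# Stand-in for each subject's epoch-averaged power: one array of shape
# [chan x freqs x time] per subject (random numbers, not real data)
per_subject = [rng.random((n_chan, n_freqs, n_times)) for _ in range(n_subj)]

# Stack subjects along a new first axis: [subj x chan x freqs x time].
# This is the array that would go into EpochsTFR (one "epoch" per subject).
power_all = np.stack(per_subject)

# Grand average over subjects: [chan x freqs x time], the AverageTFR data
grand_avg = power_all.mean(axis=0)

# Group-level test input: average CHANNELS and keep subjects first
X = power_all.mean(axis=1)  # [subj x freqs x time]

# By contrast, averaging over axis 0 first leaves no observations to test
no_obs = power_all.mean(axis=0).mean(axis=0)  # [freqs x time] only
```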
Additionally, what do cluster_p_values tell us? Don't they give the
significance of individual points in the power/ITC data? In other words, if
I look at one point in the power/ITC plot, should the corresponding point in
the cluster_p_values plot tell me whether that point is significant?
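When clusters are formed with a fixed threshold, each entry of cluster_p_values belongs to a whole cluster rather than to a single sample. Painting those values onto the frequency x time grid therefore gives every sample in a cluster the same value, and the statement it supports is "this cluster is significant", not "this sample is significant". With a TFCE dict as the threshold, MNE effectively reports one value per sample instead. A small NumPy sketch with two hand-made clusters (boolean masks, as returned with out_type='mask'; the p-values are made up):

```python
import numpy as np

n_freqs, n_times = 6, 8
# Two hand-made clusters as boolean masks, plus one p-value per cluster
c1 = np.zeros((n_freqs, n_times), dtype=bool)
c1[1:3, 2:5] = True  # 6 samples
c2 = np.zeros((n_freqs, n_times), dtype=bool)
c2[4, 6:8] = True    # 2 samples
clusters, cluster_p_values = [c1, c2], [0.01, 0.40]

p_map = np.full((n_freqs, n_times), np.nan)  # NaN = outside every cluster
sig_map = np.zeros((n_freqs, n_times), dtype=bool)
for mask, p_val in zip(clusters, cluster_p_values):
    p_map[mask] = p_val   # every sample in the cluster gets the same p
    if p_val <= 0.05:
        sig_map |= mask   # "significant" is a property of the whole cluster
```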
I would be very thankful if someone could help me with these questions.
Thank you,
-Mary
"""
==============================================================================
*General explanation of the code:*
1. create all power and itc data (Function *create_power_itc_all_avgAll*)
2. For each frequency band, run the *spatio_temporal_cluster_1samp_test *
and plot the power and itc data with their corresponding
statistical cluster_p_values.
(function *stat_plot_power*)
==============================================================================
"""
iter_freqs = [('Low Freq.', 3, 30),
('High Freq.', 30, 50)]
power_all, itc_all, power_avgAll, itc_avgAll = *create_power_itc_all_avgAll*
()
for (band, fmin, fmax) in iter_freqs:
params = fmin, fmax, power_all, itc_all, power_avgAll, itc_avgAll
*stat_plot_power*(params)
#----------------------------------------------------------------------------------------------------------------------------------------
"""
==============================================================================
Related Functions
- *create_power_itc_all_avgAll*
- *stat_plot_power*
==============================================================================
"""
from mne.stats import spatio_temporal_cluster_1samp_test
from mne.viz.utils import center_cmap
from functools import partial
from mne.stats import ttest_1samp_no_p
import numpy as np
import matplotlib.pyplot as plt
import mne
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf
from mne.time_frequency import tfr_morlet
from numpy.random import randn
from mne.baseline import rescale
"""
==============================================================================
Func:
Output: Apply morlet and save the time-frequncy data
==============================================================================
"""
def *create_power_itc_all_avgAll*():
"""
==============================================================================
==============================================================================
Set params
==============================================================================
"""
selected_subj = list(range(1,20))
n_subj = len(selected_subj)
tmin_crop = -0.3;
tmax_crop = 0.3
#=========================================================================
# === just for initiating some params, I need to read one epoch to fill
them out
freqs = np.arange(3., 50., 1.)
n_cycles = freqs / 2.
#=========================================================================
n_freqs = len(freqs)
power_all = dict();itc_all = dict()#[subj * chan * freqs * time]
power_avgAll = dict(); itc_avgAll = dict()#[chan * freqs * time]
runs = [6, 10, 14] # use only hand and feet motor imagery runs
"""
==============================================================================
Read one subject to initialize some params
==============================================================================
"""
subject=1
fnames = eegbci.load_data(subject, runs)
raws = [read_raw_edf(f, preload=True) for f in fnames]
raw = concatenate_raws(raws)
raw.rename_channels(lambda x: x.strip('.')) # remove dots from channel
names
events, _ = mne.events_from_annotations(raw)
picks = mne.pick_channels(raw.info["ch_names"], ["C3", "Cz", "C4"])
# epoch data
##################################################################
tmin, tmax = -1.5, 2.5 # define epochs around events (in s)
event_ids = dict(hands=2, feet=3) # map event IDs to tasks
epochs = mne.Epochs(raw, events, event_ids, tmin, tmax,
picks=picks, baseline=None, preload=True)
n_epochs, n_chan, n_times = epochs.crop(tmin_crop,
tmax_crop).get_data().shape
power_all_subj = randn(n_subj, n_chan, n_freqs, n_times) * 0
itc_all_subj = randn(n_subj, n_chan, n_freqs, n_times) * 0
"""
==============================================================================
Read all subject
Apply tfr_morlet
==============================================================================
"""
subj_num_id=0
for subject in selected_subj:
fnames = eegbci.load_data(subject, runs)
raws = [read_raw_edf(f, preload=True) for f in fnames]
raw = concatenate_raws(raws)
raw.rename_channels(lambda x: x.strip('.')) # remove dots from
channel names
events, _ = mne.events_from_annotations(raw)
picks = mne.pick_channels(raw.info["ch_names"], ["C3", "Cz", "C4"])
# epoch data
##################################################################
epochs = mne.Epochs(raw, events, event_ids, tmin, tmax,
picks=picks, baseline=None, preload=True)
# Run TF decomposition overall epochs
tfr_pwr, tfr_itc = tfr_morlet(epochs, freqs=freqs,
n_cycles=n_cycles,return_itc=True, average=True)
power_all_subj[subj_num_id,:,:,:] = tfr_pwr.crop(tmin_crop,
tmax_crop).data
itc_all_subj[subj_num_id,:,:,:] = tfr_itc.crop(tmin_crop,
tmax_crop).data
info = tfr_pwr.crop(tmin_crop, tmax_crop).info
times = tfr_pwr.crop(tmin_crop, tmax_crop).times
subj_num_id+=1
# ---------------------------------------------------------------------
# [subj * chan * freqs * time] - EpochsTFR
power_all = mne.time_frequency.EpochsTFR(info, power_all_subj, times,
freqs)
itc_all = mne.time_frequency.EpochsTFR(info, itc_all_subj, times, freqs)
# -----------------------------------------------------------------
#all subjects, for one cond, one freq band
# [chan * freqs * time]- AverageTFR
nave = power_all.data.shape[0]
times = power_all.times
info = power_all.info
power_avg_subj = power_all.data.mean(axis=0)
power_avgAll = mne.time_frequency.AverageTFR(info, power_avg_subj,
times, freqs, nave)
itc_avg_subj = itc_all.data.mean(axis=0)
itc_avgAll = mne.time_frequency.AverageTFR(info, itc_avg_subj, times,
freqs, nave)
return power_all, itc_all, power_avgAll, itc_avgAll
"""
==============================================================================
Func:
Inputs: power and itc data
Output: apply stats and plot the results
==============================================================================
"""
def *stat_plot_power*(params):
"""
==============================================================================
==============================================================================
Set params
==============================================================================
"""
fmin, fmax, power_all, itc_all, power_avgAll, itc_avgAll = params
baseline = (-1, 0)
col_plt = num_block = 1; row_plt = 2
s11=25; s12=15; ii = 1 + num_block; bii=1
f_p, axs_p = plt.subplots(1,num_block, figsize=(s11, s12), sharex=True,
sharey=False)
times = power_avgAll.times
vmin_p = -0.5; vmax_p = 0.5;
"""
==============================================================================
Pick channels and frequencies of interest
Apply baseline correction
==============================================================================
"""
dp = np.mean(power_avgAll.data, axis = 0)
dp = dp[fmin:fmax,:]
dp = *rescale*(dp, times, baseline, mode='mean', copy=False)
"""
==============================================================================
*Normalize power values in [vmin_p, vmax_p] range*
==============================================================================
"""
data_power_diff = (((dp - np.min(dp)) * (vmax_p - (vmin_p))) /
(np.max(dp) - np.min(dp))) + (vmin_p)
"""
==============================================================================
Plot the pure power data
==============================================================================
"""
cmap = center_cmap(plt.cm.jet, vmin_p, vmax_p) # zero maps to white
fs2=12; lw = 2
labelsize_mj = 10; labelpad_size= 12
plt.subplot(row_plt, col_plt, bii)
plt.imshow(data_power_diff, vmin=vmin_p, vmax=vmax_p,
extent=[times[0], times[-1], fmin, fmax],
aspect='auto', origin='lower', cmap=cmap)
if bii==1:
plt.ylabel('Frequency (Hz)', labelpad=labelpad_size)
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.tick_params(axis = 'both', which = 'major', labelsize =
labelsize_mj)
plt.axvline(x=0, color='yellow', linewidth=lw)
plt.axvline(x=-0.05, color='gray', linewidth=lw)
plt.axvline(x=0.05, color='gray', linewidth=lw)
if fmax < 40:
plt.axhline(y=7, color='yellow', linewidth=lw)
plt.axhline(y=11, color='yellow', linewidth=lw)
plt.axhline(y=15, color='yellow', linewidth=lw)
f_p.subplots_adjust(right=0.9)
cbar_ax = f_p.add_axes([0.92, 0.55, 0.02, 0.3])
plt.colorbar(cax = cbar_ax)
"""
==============================================================================
* Apply statistical functions*
==============================================================================
"""
thresh_p_val = 0.05
n_permutations = 'all'
* tfce = dict(start=0, step=.02) *
*sigma = 1e-3 # sigma for the "hat" method*
* stat_fun_hat = partial(ttest_1samp_no_p, sigma=sigma)*
all_data_diff = power_all.data
all_data_diff = all_data_diff.data
all_data_diff = np.mean(all_data_diff, axis=0)
all_data_diff = all_data_diff[:,fmin:fmax,:]
X = all_data_diff
T_obs_i, clusters_i, *cluster_p_values_i*, H0_i = clu_i = \
*spatio_temporal_cluster_1samp_test*(X, tfce, n_permutations,
stat_fun=stat_fun_hat, tail=0)
"""
==============================================================================
* Plot the stats*
==============================================================================
"""
p_lims = [0, 0.01];
vmin_pval=p_lims[0]; vmax_pval=p_lims[1];
cmap = center_cmap(plt.cm.jet, vmin_pval, vmax_pval) # zero maps to
white
T_obs, clusters, *cluster_p_values* = T_obs_i, clusters_i,
cluster_p_values_i
T_obs_plot = np.nan * np.ones_like(T_obs)
for c, p_val in zip(clusters, *cluster_p_values*):
if p_val <= thresh_p_val:
print(p_val)
T_obs_plot[c] = -np.log10(p_val)
plt.subplot(row_plt, col_plt, ii)
plt.imshow(T_obs_plot, cmap=cmap,
extent=[times[0], times[-1], fmin, fmax],
aspect='auto', origin='lower', vmin=vmin_pval,
vmax=vmax_pval)
plt.xlabel('Time (ms)', fontsize=fs2)
plt.ylabel('Frequency (Hz)', fontsize=fs2)
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.tick_params(axis = 'both', which = 'major', labelsize =
labelsize_mj)
plt.axvline(x=0, color='yellow', linewidth=lw)
plt.axvline(x=-0.05, color='gray', linewidth=lw)
plt.axvline(x=0.05, color='gray', linewidth=lw)
if fmax < 40:
plt.axhline(y=7, color='yellow', linewidth=lw)
plt.axhline(y=11, color='yellow', linewidth=lw)
plt.axhline(y=15, color='yellow', linewidth=lw)
f_p.subplots_adjust(right=0.9)
cbar_ax = f_p.add_axes([0.92, 0.15, 0.02, 0.3])
plt.colorbar(cax=cbar_ax, shrink=0.75,
fraction=0.1, pad=0.025)
plt.clim(vmin_pval, vmax_pval)
ii+=1
bii+=1
f_p.savefig('fig_power.png')
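Two details in the plotting function are easy to get wrong. First, band limits given in Hz have to be converted to row indices of freqs: slicing an array with the Hz values treats them as indices, which shifts the band whenever freqs does not start at 0 Hz in 1-Hz steps. Second, mode='mean' baseline correction subtracts the mean over the baseline window, so that window has to lie inside the (cropped) time range. A NumPy sketch of both; mean_baseline is a hypothetical helper that assumes an inclusive window, not MNE's rescale itself:

```python
import numpy as np

freqs = np.arange(3., 50., 1.)  # 3, 4, ..., 49 Hz as in the script
fmin, fmax = 3, 30

# Slicing with the Hz values treats them as indices: rows 3..29 -> 6..32 Hz
wrong = freqs[fmin:fmax]
# Selecting by value gives the intended 3..29 Hz band (upper edge excluded)
fmask = (freqs >= fmin) & (freqs < fmax)
right = freqs[fmask]


def mean_baseline(data, times, tmin=None, tmax=0.0):
    """Subtract the mean over [tmin, tmax] along the last (time) axis.

    Hypothetical sketch of mode='mean' baseline correction; tmin=None
    means "from the first sample".
    """
    lo = times[0] if tmin is None else tmin
    win = (times >= lo) & (times <= tmax)
    return data - data[..., win].mean(axis=-1, keepdims=True)


times = np.arange(-3, 4) / 10.0        # -0.3, -0.2, ..., 0.3 s
tfr = np.tile(np.arange(7.0), (2, 1))  # 2 freqs x 7 times, ramp 0..6
corrected = mean_baseline(tfr, times)  # baseline mean = mean(0, 1, 2, 3)
```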