[Mne_analysis] Spatiotemporal Cluster Analysis of Time-Frequency Data

Maryam Zolfaghar Maryam.Zolfaghar at colorado.edu
Thu May 9 14:40:18 EDT 2019

Hi all,

I have been using *spatio_temporal_cluster_1samp_test* on my
time-frequency data (power as an EpochsTFR with shape [epochs x chan x freqs
x time]). I have read the paper cited for the permutation tests
(Maris & Oostenveld, 2007), but it is still not clear to me exactly what
these functions are doing, and I am confused about the cluster analysis
after seeing my plots. I have attached my code below, using a public
dataset (eegbci) as an example. The code may look long, but most of it is
plotting. There are two main functions, and I have marked the most
important parts in bold with a larger font.
More specifically, is converting the participants' power data to
*EpochsTFR* and *AverageTFR* the correct way to analyze multiple
participants? That is, I compute the power/ITC data for each participant,
save them in a multidimensional array, and then convert that array to an
EpochsTFR or AverageTFR object.
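For readers following along, here is a plain-NumPy sketch of the array bookkeeping this question involves, with random stand-in data and hypothetical sizes (19 subjects, 3 channels, 47 frequencies, 97 time points). Stacking one averaged [chan x freqs x time] array per participant gives a [subj x chan x freqs x time] array in which subjects take the place of epochs. Averaging over axis 0 gives the grand average, whereas a one-sample test across participants (MNE's cluster functions expect the observations on the first axis) needs axis 0 kept and some other dimension reduced instead:

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration
n_subj, n_chan, n_freqs, n_times = 19, 3, 47, 97
rng = np.random.default_rng(0)

# One averaged TFR array per subject (like AverageTFR.data): [chan x freqs x time]
per_subject = [rng.random((n_chan, n_freqs, n_times)) for _ in range(n_subj)]

# Stack along a new first axis -> [subj x chan x freqs x time];
# subjects now play the role that epochs play in an EpochsTFR
all_subj = np.stack(per_subject, axis=0)

# Grand average over subjects (axis 0) -> [chan x freqs x time]
grand_avg = all_subj.mean(axis=0)

# Input for a one-sample test across subjects: keep subjects on axis 0 and
# reduce another dimension instead, e.g. average over channels
# -> [subj x freqs x time]
X = all_subj.mean(axis=1)

print(all_subj.shape, grand_avg.shape, X.shape)
```

Averaging over axis 0 of the stacked array is therefore only appropriate for the grand-average plot, never for building the test input.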
Additionally, what do cluster_p_values tell us? Don't they show the
significance of individual points in the power/ITC data? In other words, if
I look at one point in the power/ITC plot, should the corresponding point in
the cluster_p_values plot tell me whether or not that point is significant?
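To illustrate the distinction behind this question, here is a minimal from-scratch sketch of a cluster-mass permutation test in the spirit of Maris & Oostenveld (2007), written in plain NumPy on a synthetic 1-D signal. It is not MNE's implementation (the fixed t threshold, the 1-D adjacency, and the sign-flip scheme are simplifying assumptions, and `clusters_1d`, `t_stat`, and `cluster_perm_test` are hypothetical helpers), but it shows that a standard cluster test returns one p-value per cluster, which can then be copied onto that cluster's points:

```python
import numpy as np


def clusters_1d(mask):
    """Return (start, stop) index pairs of contiguous True runs in a 1-D mask."""
    padded = np.concatenate(([False], mask, [False])).astype(int)
    edges = np.flatnonzero(np.diff(padded))
    return list(zip(edges[::2], edges[1::2]))


def t_stat(X):
    """One-sample t statistic along axis 0 (the observations)."""
    n = X.shape[0]
    return X.mean(axis=0) / (X.std(axis=0, ddof=1) / np.sqrt(n))


def cluster_perm_test(X, t_thresh, n_perm=1000, seed=0):
    """Sign-flip cluster-mass permutation test for X of shape [obs x time]."""
    rng = np.random.default_rng(seed)
    t_obs = t_stat(X)
    clusters = clusters_1d(np.abs(t_obs) > t_thresh)
    masses = np.array([np.abs(t_obs[a:b]).sum() for a, b in clusters])

    # Null distribution of the LARGEST cluster mass under random sign flips
    max_null = np.zeros(n_perm)
    for i in range(n_perm):
        signs = rng.choice([-1.0, 1.0], size=(X.shape[0], 1))
        t_perm = t_stat(X * signs)
        perm_clusters = clusters_1d(np.abs(t_perm) > t_thresh)
        if perm_clusters:
            max_null[i] = max(np.abs(t_perm[a:b]).sum() for a, b in perm_clusters)

    # One p-value per cluster, not per point
    p_vals = np.array([(max_null >= m).mean() for m in masses])
    return t_obs, clusters, p_vals


rng = np.random.default_rng(1)
X = rng.normal(size=(19, 100))  # 19 "subjects", 100 time points of noise
X[:, 40:60] += 1.0              # a true effect at points 40..59

t_obs, clusters, p_vals = cluster_perm_test(X, t_thresh=2.1)

# Copy each cluster's p-value onto its points to get a per-point map
p_map = np.ones_like(t_obs)
for (a, b), p in zip(clusters, p_vals):
    p_map[a:b] = p
print(len(clusters), p_vals.round(3))
```

Under this scheme a cluster is significant as a whole; the test does not support the claim that each individual point inside it is significant. As far as I understand MNE's behavior, when `threshold` is a TFCE dict, as in the script below, each data point is treated as its own cluster, so `cluster_p_values` then holds one value per point.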

I would be very thankful if someone could help me with these questions.




Thank you,
-Mary




"""
==============================================================================
General explanation of the code:
1. Create the power and ITC data for all subjects
   (function create_power_itc_all_avgAll).
2. For each frequency band, run spatio_temporal_cluster_1samp_test
   and plot the power and ITC data with their corresponding
   statistical cluster_p_values (function stat_plot_power).
==============================================================================
"""

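# Driver code. When running this as a script, place it after the imports
# and function definitions below so that the functions are defined first.
iter_freqs = [('Low Freq.', 3, 30),
              ('High Freq.', 30, 50)]
power_all, itc_all, power_avgAll, itc_avgAll = create_power_itc_all_avgAll()

for (band, fmin, fmax) in iter_freqs:
    params = fmin, fmax, power_all, itc_all, power_avgAll, itc_avgAll
    stat_plot_power(params)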

#----------------------------------------------------------------------------------------------------------------------------------------


"""
==============================================================================
Related Functions
  - create_power_itc_all_avgAll
  - stat_plot_power
==============================================================================
"""
from mne.stats import spatio_temporal_cluster_1samp_test
from mne.viz.utils import center_cmap
from functools import partial
from mne.stats import ttest_1samp_no_p
import numpy as np
import matplotlib.pyplot as plt
import mne
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf
from mne.time_frequency import tfr_morlet
from numpy.random import randn
from mne.baseline import rescale

"""
==============================================================================
Func:
    Output: apply Morlet wavelets and return the time-frequency data
==============================================================================
"""

def create_power_itc_all_avgAll():

    """
    ==========================================================================
    Set params
    ==========================================================================
    """
    selected_subj = list(range(1, 20))  # subjects 1-19
    n_subj = len(selected_subj)
    tmin_crop = -0.3
    tmax_crop = 0.3

    # Morlet wavelet parameters: 3-49 Hz in 1 Hz steps, n_cycles = freq / 2
    freqs = np.arange(3., 50., 1.)
    n_cycles = freqs / 2.
    n_freqs = len(freqs)

    runs = [6, 10, 14]  # use only hand and feet motor imagery runs

    """
    ==========================================================================
    Read one subject to initialize some params
    (only needed to get the number of channels and time points)
    ==========================================================================
    """
    subject = 1
    fnames = eegbci.load_data(subject, runs)
    raws = [read_raw_edf(f, preload=True) for f in fnames]
    raw = concatenate_raws(raws)
    # remove dots from channel names
    raw.rename_channels(lambda x: x.strip('.'))

    events, _ = mne.events_from_annotations(raw)

    picks = mne.pick_channels(raw.info["ch_names"], ["C3", "Cz", "C4"])

    # epoch data
    tmin, tmax = -1.5, 2.5  # define epochs around events (in s)
    event_ids = dict(hands=2, feet=3)  # map event IDs to tasks

    epochs = mne.Epochs(raw, events, event_ids, tmin, tmax,
                        picks=picks, baseline=None, preload=True)

    # Only the shape of the cropped data is needed here
    n_epochs, n_chan, n_times = epochs.crop(tmin_crop,
                                            tmax_crop).get_data().shape
    # [subj x chan x freqs x time]
    power_all_subj = np.zeros((n_subj, n_chan, n_freqs, n_times))
    itc_all_subj = np.zeros((n_subj, n_chan, n_freqs, n_times))

    """
    ==========================================================================
    Read all subjects and apply tfr_morlet
    ==========================================================================
    """
    for subj_num_id, subject in enumerate(selected_subj):
        fnames = eegbci.load_data(subject, runs)
        raws = [read_raw_edf(f, preload=True) for f in fnames]
        raw = concatenate_raws(raws)
        # remove dots from channel names
        raw.rename_channels(lambda x: x.strip('.'))
        events, _ = mne.events_from_annotations(raw)
        picks = mne.pick_channels(raw.info["ch_names"], ["C3", "Cz", "C4"])

        # epoch data
        epochs = mne.Epochs(raw, events, event_ids, tmin, tmax,
                            picks=picks, baseline=None, preload=True)

        # Run the TF decomposition over all epochs; average=True returns
        # this subject's average power and inter-trial coherence (ITC)
        tfr_pwr, tfr_itc = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles,
                                      return_itc=True, average=True)

        # crop() modifies the object in place, so crop once and reuse it
        tfr_pwr.crop(tmin_crop, tmax_crop)
        tfr_itc.crop(tmin_crop, tmax_crop)
        power_all_subj[subj_num_id] = tfr_pwr.data
        itc_all_subj[subj_num_id] = tfr_itc.data

        info = tfr_pwr.info
        times = tfr_pwr.times

    # ---------------------------------------------------------------------
    # [subj x chan x freqs x time] -> EpochsTFR, one "epoch" per subject
    power_all = mne.time_frequency.EpochsTFR(info, power_all_subj, times,
                                             freqs)
    itc_all = mne.time_frequency.EpochsTFR(info, itc_all_subj, times, freqs)
    # ---------------------------------------------------------------------

    # Grand average over subjects, all frequencies:
    # [chan x freqs x time] -> AverageTFR
    nave = power_all.data.shape[0]
    times = power_all.times
    info = power_all.info

    power_avg_subj = power_all.data.mean(axis=0)
    power_avgAll = mne.time_frequency.AverageTFR(info, power_avg_subj,
                                                 times, freqs, nave)

    itc_avg_subj = itc_all.data.mean(axis=0)
    itc_avgAll = mne.time_frequency.AverageTFR(info, itc_avg_subj, times,
                                               freqs, nave)

    return power_all, itc_all, power_avgAll, itc_avgAll

"""
 ==============================================================================
 Func:
     Inputs: power and itc data
     Output: apply stats and plot the results
 ==============================================================================
"""


def stat_plot_power(params):

    """
    ==========================================================================
    Set params
    ==========================================================================
    """
    fmin, fmax, power_all, itc_all, power_avgAll, itc_avgAll = params
    # The data were cropped to -0.3..0.3 s, so this baseline effectively
    # runs from the first available sample (-0.3 s) to 0 s
    baseline = (-1, 0)
    col_plt = num_block = 1
    row_plt = 2
    s11, s12 = 25, 15   # figure size in inches
    ii = 1 + num_block  # subplot index of the statistics panel
    bii = 1             # subplot index of the power panel
    f_p, axs_p = plt.subplots(1, num_block, figsize=(s11, s12), sharex=True,
                              sharey=False)
    times = power_avgAll.times
    vmin_p, vmax_p = -0.5, 0.5

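    """
    ==========================================================================
    Pick channels and frequencies of interest
    Apply baseline correction
    ==========================================================================
    """
    # Average over channels -> [freqs x time]
    dp = np.mean(power_avgAll.data, axis=0)
    # Select the band by frequency value (Hz). freqs start at 3 Hz, so
    # dp[fmin:fmax] would select frequency indices and shift the band by 3 Hz.
    band_mask = (power_avgAll.freqs >= fmin) & (power_avgAll.freqs < fmax)
    dp = dp[band_mask, :]
    dp = rescale(dp, times, baseline, mode='mean', copy=False)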

    """
    ==========================================================================
    Scale power values into the [vmin_p, vmax_p] range
    ==========================================================================
    """
    # Divide by the largest absolute value so that 0 (no change from
    # baseline) stays at 0 and maps to white in the centered colormap;
    # a min-max rescaling would shift the zero point.
    data_power_diff = dp / np.max(np.abs(dp)) * vmax_p


    """
    ==========================================================================
    Plot the baseline-corrected power data
    ==========================================================================
    """
    cmap = center_cmap(plt.cm.jet, vmin_p, vmax_p)  # zero maps to white

    fs2=12; lw = 2
    labelsize_mj = 10; labelpad_size= 12


    plt.subplot(row_plt, col_plt, bii)
    plt.imshow(data_power_diff, vmin=vmin_p, vmax=vmax_p,
                    extent=[times[0], times[-1], fmin, fmax],
                    aspect='auto', origin='lower', cmap=cmap)

    if bii==1:
        plt.ylabel('Frequency (Hz)', labelpad=labelpad_size)

    plt.rcParams["font.weight"] = "bold"
    plt.rcParams["axes.labelweight"] = "bold"
    plt.tick_params(axis = 'both', which = 'major', labelsize =
labelsize_mj)
    plt.axvline(x=0, color='yellow', linewidth=lw)
    plt.axvline(x=-0.05, color='gray', linewidth=lw)
    plt.axvline(x=0.05, color='gray', linewidth=lw)

    if fmax < 40:
        plt.axhline(y=7, color='yellow', linewidth=lw)
        plt.axhline(y=11, color='yellow', linewidth=lw)
        plt.axhline(y=15, color='yellow', linewidth=lw)


    f_p.subplots_adjust(right=0.9)
    cbar_ax = f_p.add_axes([0.92, 0.55, 0.02, 0.3])
    plt.colorbar(cax = cbar_ax)


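    """
    ==========================================================================
    Apply statistical functions
    ==========================================================================
    """
    thresh_p_val = 0.05
    # 'all' enumerates every sign-flip permutation; their number grows as
    # 2 ** n_subjects, which can be slow in combination with TFCE
    n_permutations = 'all'
    tfce = dict(start=0, step=.02)  # threshold-free cluster enhancement
    sigma = 1e-3  # sigma for the "hat" variance regularization
    stat_fun_hat = partial(ttest_1samp_no_p, sigma=sigma)

    # Subjects must stay on axis 0 as the observations. Averaging over
    # axis 0 would average over subjects and leave the 3 channels as the
    # "observations", so average over channels instead -> [subj x freqs x time]
    X = np.mean(power_all.data, axis=1)
    band_mask = (power_all.freqs >= fmin) & (power_all.freqs < fmax)
    X = X[:, band_mask, :]
    # Raw power is always positive, so test the change from baseline against 0
    X = rescale(X, times, baseline, mode='mean', copy=True)

    T_obs_i, clusters_i, cluster_p_values_i, H0_i = clu_i = \
        spatio_temporal_cluster_1samp_test(X, threshold=tfce,
                                           n_permutations=n_permutations,
                                           stat_fun=stat_fun_hat, tail=0)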

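    """
    ==========================================================================
    Plot the stats
    ==========================================================================
    """
    T_obs, clusters, cluster_p_values = T_obs_i, clusters_i, cluster_p_values_i
    # Copy each significant cluster's p-value onto its points, as -log10(p).
    # With TFCE, every point is its own "cluster".
    T_obs_plot = np.nan * np.ones_like(T_obs)
    for c, p_val in zip(clusters, cluster_p_values):
        if p_val <= thresh_p_val:
            print(p_val)
            T_obs_plot[c] = -np.log10(p_val)

    # Color limits in the same -log10(p) units as T_obs_plot: from the
    # significance threshold up to the strongest observed effect
    vmin_pval = -np.log10(thresh_p_val)
    significant = ~np.isnan(T_obs_plot)
    vmax_pval = T_obs_plot[significant].max() if significant.any() else vmin_pval
    vmax_pval = max(vmax_pval, vmin_pval + 0.1)  # keep a non-empty color range
    cmap = plt.cm.jet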


    plt.subplot(row_plt, col_plt, ii)
    plt.imshow(T_obs_plot, cmap=cmap,
               extent=[times[0], times[-1], fmin, fmax],
               aspect='auto', origin='lower', vmin=vmin_pval,
vmax=vmax_pval)

    plt.xlabel('Time (s)', fontsize=fs2)
    plt.ylabel('Frequency (Hz)', fontsize=fs2)
    plt.rcParams["font.weight"] = "bold"
    plt.rcParams["axes.labelweight"] = "bold"
    plt.tick_params(axis = 'both', which = 'major', labelsize =
labelsize_mj)
    plt.axvline(x=0, color='yellow', linewidth=lw)
    plt.axvline(x=-0.05, color='gray', linewidth=lw)
    plt.axvline(x=0.05, color='gray', linewidth=lw)

    if fmax < 40:
        plt.axhline(y=7, color='yellow', linewidth=lw)
        plt.axhline(y=11, color='yellow', linewidth=lw)
        plt.axhline(y=15, color='yellow', linewidth=lw)

    f_p.subplots_adjust(right=0.9)
    cbar_ax = f_p.add_axes([0.92, 0.15, 0.02, 0.3])
    plt.colorbar(cax=cbar_ax, shrink=0.75,
                fraction=0.1, pad=0.025)
    plt.clim(vmin_pval, vmax_pval)

    ii+=1
    bii+=1

    # Save inside the function, where f_p is defined; one file per band
    f_p.savefig('fig_power_%d-%dHz.png' % (fmin, fmax))
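Two details of the statistics step can be checked in isolation. The plain-NumPy sketch below uses the same frequency grid and 160 Hz sampling as the script, with random, strictly positive stand-in power in an assumed [subj x freqs x time] layout. It shows that integer slicing with `fmin:fmax` selects frequency indices rather than Hz, and it applies a "mean" baseline correction along the time axis (subtracting the mean over all samples up to and including 0 s), which is what I understand `mne.baseline.rescale(..., mode='mean')` to do:

```python
import numpy as np

# Same grid as the script: 3-49 Hz in 1 Hz steps, -0.3..0.3 s at 160 Hz
freqs = np.arange(3., 50., 1.)
times = np.arange(-48, 49) / 160.

# Random, strictly positive stand-in for power, shape [subj x freqs x time]
rng = np.random.default_rng(0)
power = rng.random((19, freqs.size, times.size)) + 1.0

fmin, fmax = 3, 30
# Integer slicing selects frequency *indices* 3..29, i.e. 6..32 Hz
by_index = freqs[fmin:fmax]
# A boolean mask on the frequency values selects 3..29 Hz
band = (freqs >= fmin) & (freqs < fmax)

# 'mean' baseline correction: subtract, per subject and frequency, the mean
# over the baseline samples (t <= 0), so that 0 means "no change"
baseline = times <= 0
X = power[:, band, :]
X = X - X[..., baseline].mean(axis=-1, keepdims=True)

print(by_index[[0, -1]], freqs[band][[0, -1]], X.shape)
```

A one-sample test of the uncorrected power against 0 would be trivially significant, since every value is positive; after the correction, the test asks whether power differs from its own baseline level.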

