# Source code for torch_specinv.methods

from .metrics import *
import torch
import torch.nn as nn
import torch.fft as fft
import torch.nn.functional as F
from torch.optim.lbfgs import LBFGS
from tqdm import tqdm
from functools import partial
from typing import Tuple
import math

# One full turn in radians; `phase_init` uses it to turn a (fractional) bin
# index into a per-hop phase advance.
pi2 = math.pi * 2

# Upper-cased metric name -> evaluation function. `_training_loop` upper-cases
# the user supplied `metric` string before looking it up here.
# NOTE(review): sc / snr / ser are presumably provided by the star import from
# `.metrics` -- confirm.
_func_mapper = dict(SC=sc, SNR=snr, SER=ser)


def _args_helper(spec, **stft_kwargs):
    """A helper function to get stft arguments from the provided kwargs.

    Args:
        spec: The magnitude spectrum of size (*, freq, time).
        **stft_kwargs: Keyword arguments that computed spec from 'torch.stft'.
        See `torch.stft` for details.

    Returns:
        n_fft: FFT size of the spectrum.
        processed_kwargs: Dict object that stored the processed keyword arguments.

    """
    args_dict = {'win_length': None,
                 'window': None,
                 'hop_length': None,
                 'center': True,
                 'pad_mode': 'reflect',
                 'normalized': False,
                 'onesided': None,
                 'return_complex': None}
    for key, item in args_dict.items():
        try:
            args_dict[key] = stft_kwargs[key]
        except:
            pass
    win_length, window, hop_length, center, pad_mode, normalized, onesided, return_complex = tuple(
        args_dict.values())

    device = spec.device
    dtype = spec.dtype
    if dtype == torch.complex32:
        dtype = torch.float16
    elif dtype == torch.complex64:
        dtype = torch.float32
    elif dtype == torch.complex128:
        dtype = torch.float64

    if onesided is None:
        if window is not None and window.is_complex():
            onesided = False
        else:
            onesided = True

    if onesided:
        n_fft = (spec.shape[-2] - 1) * 2
    else:
        n_fft = spec.shape[-2]

    if not win_length:
        win_length = n_fft

    if not hop_length:
        hop_length = n_fft // 4

    if window is None:
        window = torch.ones(win_length, dtype=dtype, device=device)

    assert n_fft >= win_length
    if n_fft > win_length:
        window = F.pad(window, [(n_fft - win_length) //
                                2, (n_fft - win_length + 1) // 2])
        win_length = n_fft

    args_dict['win_length'] = win_length
    args_dict['hop_length'] = hop_length
    args_dict['window'] = window
    args_dict['return_complex'] = True
    args_dict['onesided'] = onesided

    return n_fft, args_dict


def _get_ola_weight(window):
    ola_weight = torch.diag(window).unsqueeze(1)
    return ola_weight


def _spec_formatter(spec, **stft_kwargs):
    shape = spec.shape
    assert 4 > len(shape) > 1
    if len(shape) == 2:
        spec = spec.unsqueeze(0)

    if not spec.is_complex():
        cmplx_spec = phase_init(spec, **stft_kwargs)
        target_spec = spec
    else:
        cmplx_spec = spec
        target_spec = spec.abs()
    return cmplx_spec, target_spec


def _ola(x, hop_length, weight, padding, norm_envelope=None):
    """A helper function to do overlap-and-add.

    Args:
        x: input tensor of size :math: '(batch, window_size, time)'.
        hop_length: The distance between neighboring sliding window frames.
        weight: An identity matrix of size (win_length x win_length) .
        norm_envelope: The normalized coefficient apply on synthesis window.

    Returns:
        A 1d tensor containing the overlap-and-add result.

    """
    ola_x = F.conv_transpose1d(
        x, weight, stride=hop_length, padding=padding).squeeze(1)
    if norm_envelope is None:
        norm_envelope = F.conv_transpose1d(torch.ones_like(
            x[:1]), weight * weight, stride=hop_length, padding=padding).squeeze()
    return ola_x / norm_envelope, norm_envelope


def _istft(x, n_fft, ola_weight,
           win_length, window, hop_length, center, normalized, onesided, pad_mode, return_complex,
           norm_envelope=None):
    """Invert a spectrogram: inverse FFT per frame followed by overlap-and-add.

    ``win_length``, ``window``, ``pad_mode`` and ``return_complex`` are not
    read here; they are accepted so that the dict returned by
    ``_args_helper`` can be splatted into this function unchanged.

    Args:
        x: Complex spectrogram with frequency on dim -2.
        n_fft: FFT size.
        ola_weight: Overlap-and-add kernel, see ``_get_ola_weight``.
        hop_length: The distance between neighboring frames.
        center: When true, ``n_fft // 2`` samples are trimmed from both ends.
        normalized: Selects the ``'ortho'`` FFT norm instead of ``'backward'``.
        onesided: Selects ``irfft`` over the real part of ``ifft``.
        norm_envelope: Optional precomputed divisor forwarded to ``_ola``.

    Returns:
        The time-domain signal and the normalisation envelope from ``_ola``.
    """
    fft_norm = 'ortho' if normalized else 'backward'
    if onesided:
        frames = fft.irfft(x, n=n_fft, dim=-2, norm=fft_norm)
    else:
        frames = fft.ifft(x, n=n_fft, dim=-2, norm=fft_norm).real

    trim = n_fft // 2 if center else 0
    signal, envelope = _ola(frames, hop_length, ola_weight,
                            padding=trim, norm_envelope=norm_envelope)
    return signal, envelope


def _training_loop(
        closure,
        status_dict,
        target,
        max_iter,
        tol,
        verbose,
        eva_iter,
        metric,
):
    """Run ``closure`` up to ``max_iter`` times with periodic evaluation.

    Every ``eva_iter``-th call, the closure output is scored against
    ``target`` with the selected metric and with the MSE loss, and the
    progress bar is refreshed. Iteration stops early once the loss went down
    compared with the previous evaluation, but by less than ``tol`` relative
    to the first recorded (non-zero) loss.

    Args:
        closure: Callable taking ``status_dict`` and returning the current
            estimate to compare with ``target``.
        status_dict: Mutable state passed to every ``closure`` call.
        target: Reference tensor for the metric and the loss.
        max_iter: Maximum number of ``closure`` calls (must be positive).
        tol: Non-negative relative improvement threshold.
        verbose: Progress bar is shown when truthy.
        eva_iter: Evaluation period in iterations (must be positive).
        metric: Case-insensitive key of ``_func_mapper``.
    """
    assert eva_iter > 0
    assert max_iter > 0
    assert tol >= 0

    metric = metric.upper()
    assert metric in _func_mapper.keys()

    evaluate = _func_mapper[metric]
    postfix = {metric: 0}
    mse = F.mse_loss
    first_loss = None

    with tqdm(total=max_iter, disable=not verbose) as pbar:
        for step in range(1, max_iter + 1):
            output = closure(status_dict)
            if step % eva_iter:
                # Not an evaluation step.
                continue

            postfix[metric] = evaluate(output, target).item()
            current_loss = mse(output, target).item()
            pbar.set_postfix(**postfix, loss=current_loss)
            pbar.update(eva_iter)

            if not first_loss:
                # Truthiness on purpose: a zero loss is never used as the
                # divisor below.
                first_loss = current_loss
            else:
                gain = last_loss - current_loss
                if gain / first_loss < tol and last_loss > current_loss:
                    break
            last_loss = current_loss


def griffin_lim(spec, max_iter=200, tol=1e-6, alpha=0.99, verbose=True, eva_iter=10, metric='sc', **stft_kwargs):
    r"""Reconstruct spectrogram phase using the well known `Griffin-Lim`_ algorithm and its variation,
    `Fast Griffin-Lim`_.

    .. _`Griffin-Lim`: https://pdfs.semanticscholar.org/14bc/876fae55faf5669beb01667a4f3bd324a4f1.pdf
    .. _`Fast Griffin-Lim`: https://perraudin.info/publications/perraudin-note-002.pdf

    Args:
        spec (Tensor): the input tensor of size :math:`(N \times T)` or :math:`(B \times N \times T)`, either real
            (magnitude) or complex. If a magnitude spectrogram is given, the phase will first be
            initialized using :func:`torch_specinv.methods.phase_init`; otherwise start from the complex input.
        max_iter (int): maximum number of iterations before timing out.
        tol (float): tolerance of the stopping condition based on L2 loss. Default: ``1e-6``
        alpha (float): speedup parameter used in `Fast Griffin-Lim`_, set it to zero will disable it.
            Default: ``0.99``
        verbose (bool): whether to be verbose. Default: :obj:`True`
        eva_iter (int): steps size for evaluation. After each step, the function defined in ``metric`` will
            evaluate. Default: ``10``
        metric (str): evaluation function. Currently available functions: ``'sc'`` (spectral convergence),
            ``'snr'`` or ``'ser'``. Default: ``'sc'``
        **stft_kwargs: other arguments that pass to :func:`torch.stft`

    Returns:
        The time-domain signal converted from the given spectrogram. The leading batch dimension is squeezed
        away unless ``spec`` was given as a 3d tensor with a batch size of one.

    """
    assert alpha >= 0

    cmplx_spec, target_spec = _spec_formatter(spec, **stft_kwargs)
    n_fft, processed_args = _args_helper(target_spec, **stft_kwargs)
    ola_weight = _get_ola_weight(processed_args['window'])
    istft = partial(_istft, n_fft=n_fft, ola_weight=ola_weight, **processed_args)

    pre_spec = cmplx_spec.clone()
    # The envelope only depends on the window and the frame count, so it is
    # computed once here and reused in every iteration.
    x, norm_envelope = istft(cmplx_spec)
    lr = alpha / (1 + alpha)

    def closure(status_dict):
        x = status_dict['x']
        pre_spec = status_dict['pre_spec']

        new_spec = torch.stft(x, n_fft, **processed_args)
        # Magnitude of the current estimate; this is what gets evaluated.
        output = new_spec.abs()

        # Momentum-like term built from the spectrum kept from the last call.
        new_spec = new_spec - pre_spec * lr
        status_dict['pre_spec'] = new_spec

        # Keep the phase, replace the magnitude by the target one.
        norm = new_spec.abs().add_(1e-16)
        new_spec = new_spec * target_spec / norm

        x, _ = istft(new_spec, norm_envelope=norm_envelope)
        status_dict['x'] = x
        return output

    stats = {
        'x': x,
        'pre_spec': pre_spec
    }

    _training_loop(
        closure,
        stats,
        target_spec,
        max_iter,
        tol,
        verbose,
        eva_iter,
        metric
    )

    x = stats['x']
    if not (spec.shape[0] == 1 and len(spec.shape) == 3):
        x = x.squeeze(0)
    return x
def RTISI_LA(spec, look_ahead=-1, asymmetric_window=False, max_iter=25, alpha=0.99, verbose=1, **stft_kwargs):
    r"""
    Reconstruct spectrogram phase using `Real-Time Iterative Spectrogram Inversion with Look Ahead`_ (RTISI-LA).

    .. _`Real-Time Iterative Spectrogram Inversion with Look Ahead`:
        https://lonce.org/home/Publications/publications/2007_RealtimeSignalReconstruction.pdf

    Args:
        spec (Tensor): the input tensor of size :math:`(N \times T)` or :math:`(B \times N \times T)` (magnitude).
            Complex input is rejected.
        look_ahead (int): how many future frames will be considered. A negative value will set it to
            ``(win_length - 1) // hop_length``, ``0`` will disable the look-ahead strategy and fall back to the
            original RTISI algorithm. Default: ``-1``
        asymmetric_window (bool): whether to apply an asymmetric window on the first iteration for the newly
            arrived frame. Default: :obj:`False`
        max_iter (int): number of iterations for each step. Default: ``25``
        alpha (float): speedup parameter used in `Fast Griffin-Lim`_, set it to zero will disable it.
            Default: ``0.99``
        verbose (bool): whether to be verbose. Default: ``1``
        **stft_kwargs: other arguments that pass to :func:`torch.stft`.

    Returns:
        The time-domain signal converted from the given spectrogram. The leading batch dimension is squeezed
        away unless ``spec`` was given as a 3d tensor with a batch size of one.

    """
    assert max_iter > 0
    assert alpha >= 0

    assert not spec.is_complex()
    shape = spec.shape
    assert 4 > len(shape) > 1
    target_spec = spec
    if len(shape) == 2:
        target_spec = target_spec.unsqueeze(0)

    n_fft, processed_args = _args_helper(target_spec, **stft_kwargs)
    ola_weight = _get_ola_weight(processed_args['window'])
    # The analysis inside the loop works on an already framed buffer, so the
    # user supplied stft arguments are reused without centre padding.
    copyed_kwargs = stft_kwargs.copy()
    copyed_kwargs['center'] = False
    copyed_kwargs['return_complex'] = True

    win_length = processed_args['win_length']
    hop_length = processed_args['hop_length']
    onesided = processed_args['onesided']
    normalized = processed_args['normalized']
    window = processed_args['window']
    synth_coeff = hop_length / (window @ window)

    # Number of already committed frames that still overlap the newest one.
    num_keep = (win_length - 1) // hop_length
    if look_ahead < 0:
        look_ahead = num_keep

    # Asymmetric analysis windows: sums of shifted, time-reversed windows.
    # The first variant leaves out the unshifted copy, the second includes it.
    asym_window1 = target_spec.new_zeros(win_length)
    for i in range(num_keep):
        asym_window1[(i + 1) * hop_length:] += window.flip(0)[:-(i + 1) * hop_length]
    asym_window1 *= synth_coeff

    asym_window2 = target_spec.new_zeros(win_length)
    for i in range(num_keep + 1):
        asym_window2[i * hop_length:] += window.flip(0)[:-i * hop_length if i else None]
    asym_window2 *= synth_coeff

    steps = target_spec.shape[2]
    target_spec = F.pad(target_spec, [look_ahead, look_ahead])

    fft_norm = 'ortho' if normalized else 'backward'
    if onesided:
        irfft = partial(fft.irfft, n=n_fft, dim=-2, norm=fft_norm)
        rfft = partial(fft.rfft, n=n_fft, dim=-2, norm=fft_norm)
    else:
        def irfft(x):
            return fft.ifft(x, n=n_fft, dim=-2, norm=fft_norm).real

        def rfft(x):
            return fft.fft(x, n=n_fft, dim=-2, norm=fft_norm)

    # initialize first frame with zero phase
    first_frame = target_spec[..., look_ahead, None]
    keeped_chunk = target_spec.new_zeros(target_spec.shape[0], n_fft, num_keep)
    update_chunk = target_spec.new_zeros(target_spec.shape[0], n_fft, look_ahead)
    update_chunk = torch.cat((update_chunk, irfft(first_frame + 0j)), 2)

    lr = alpha / (1 + alpha)
    output_xt_list = []
    with tqdm(total=steps + look_ahead, disable=not verbose) as pbar:
        for i in range(steps + look_ahead):
            for j in range(max_iter):
                # Overlap-add the committed frames with the frames still being
                # refined, then drop the part covered only by committed ones.
                x, _ = _ola(torch.cat((keeped_chunk, update_chunk), 2), hop_length, ola_weight * synth_coeff,
                            padding=0, norm_envelope=1)
                x = x[:, num_keep * hop_length:]

                if asymmetric_window:
                    xt_winview = x.unfold(1, win_length, hop_length).transpose(1, 2)
                    xt_norm_wind = xt_winview[:, :, :-1] * window[:, None]
                    if j:
                        xt_asym_wind = xt_winview[:, :, -1:] * asym_window2[:, None]
                    else:
                        xt_asym_wind = xt_winview[:, :, -1:] * asym_window1[:, None]

                    xt = torch.cat((xt_norm_wind, xt_asym_wind), 2)
                    new_spec = rfft(xt)
                else:
                    new_spec = torch.stft(x, n_fft=n_fft, **copyed_kwargs)

                if j:
                    new_spec = new_spec - lr * pre_spec
                elif i:
                    # First iteration of a new step: the buffer moved on by one
                    # frame, so the previous spectrum is shifted and the newest
                    # frame gets no momentum term.
                    new_spec = torch.cat(
                        (new_spec[:, :, :-1] - lr * pre_spec[:, :, 1:], new_spec[:, :, -1:]), 2)
                pre_spec = new_spec

                # Keep the phase, impose the target magnitude.
                norm = new_spec.abs() + 1e-16
                new_spec = new_spec * target_spec[..., i:i + look_ahead + 1] / norm

                update_chunk = irfft(new_spec)

            pbar.update()
            # Commit the oldest frame and slide both buffers by one frame.
            output_xt_list.append(update_chunk[:, :, 0])
            keeped_chunk = torch.cat((keeped_chunk[:, :, 1:], update_chunk[:, :, :1]), 2)
            update_chunk = F.pad(update_chunk[:, :, 1:], [0, 1])

    # The first `look_ahead` committed frames come from the left padding.
    all_xt = torch.stack(output_xt_list[look_ahead:], 2)
    x, _ = _ola(all_xt, hop_length, ola_weight, padding=win_length // 2 if processed_args['center'] else 0)
    if not (spec.shape[0] == 1 and len(spec.shape) == 3):
        x = x.squeeze(0)
    return x
def ADMM(spec, max_iter=1000, tol=1e-6, rho=0.1, verbose=1, eva_iter=10, metric='sc', **stft_kwargs):
    r"""
    Reconstruct spectrogram phase using `Griffin–Lim Like Phase Recovery via Alternating Direction Method of
    Multipliers`_ .

    .. _`Griffin–Lim Like Phase Recovery via Alternating Direction Method of Multipliers`:
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8552369

    Args:
        spec (Tensor): the input tensor of size :math:`(N \times T)` or :math:`(B \times N \times T)`, either real
            (magnitude) or complex. If a magnitude spectrogram is given, the phase will first be
            initialized using :func:`torch_specinv.methods.phase_init`; otherwise start from the complex input.
        max_iter (int): maximum number of iterations before timing out.
        tol (float): tolerance of the stopping condition based on L2 loss. Default: ``1e-6``
        rho (float): non-negative speedup parameter. Small value is preferable when the input spectrogram is noisy
            (imperfect); set it to 1 will behave similar to ``griffin_lim``. Default: ``0.1``
        verbose (bool): whether to be verbose. Default: ``1``
        eva_iter (int): steps size for evaluation. After each step, the function defined in ``metric`` will
            evaluate. Default: ``10``
        metric (str): evaluation function. Currently available functions: ``'sc'`` (spectral convergence),
            ``'snr'`` or ``'ser'``. Default: ``'sc'``
        **stft_kwargs: other arguments that pass to :func:`torch.stft`.

    Returns:
        The time-domain signal converted from the given spectrogram. The leading batch dimension is squeezed
        away unless ``spec`` was given as a 3d tensor with a batch size of one.

    """
    assert eva_iter > 0
    assert max_iter > 0
    assert tol >= 0
    assert metric.upper() in list(_func_mapper.keys())

    cmplx_spec, target_spec = _spec_formatter(spec, **stft_kwargs)
    n_fft, processed_args = _args_helper(target_spec, **stft_kwargs)
    ola_weight = _get_ola_weight(processed_args['window'])
    istft = partial(_istft, n_fft=n_fft, ola_weight=ola_weight, **processed_args)

    X = cmplx_spec
    # The envelope is computed once here and reused in every iteration.
    x, norm_envelope = istft(X)
    Y = X.clone()
    U = torch.zeros_like(X)

    def closure(status_dict):
        X = status_dict['X']
        Y = status_dict['Y']
        U = status_dict['U']
        x = status_dict['x']

        reconstruted = torch.stft(x, n_fft, **processed_args)
        # Magnitude of the current estimate; this is what gets evaluated.
        output = reconstruted.abs()

        Z = (rho * Y + reconstruted) / (1 + rho)
        U = U + X - Z

        # Pc2
        X = Z - U
        norm = X.abs() + 1e-16
        X = X * target_spec / norm

        Y = X + U
        # Pc1
        x, _ = istft(Y, norm_envelope=norm_envelope)

        status_dict['Y'] = Y
        status_dict['X'] = X
        status_dict['U'] = U
        status_dict['x'] = x
        return output

    stats = {
        'Y': Y,
        'U': U,
        'X': X,
        'x': x
    }

    _training_loop(
        closure,
        stats,
        target_spec,
        max_iter,
        tol,
        verbose,
        eva_iter,
        metric
    )

    x = stats['x']
    if not (spec.shape[0] == 1 and len(spec.shape) == 3):
        x = x.squeeze(0)
    return x
def L_BFGS(spec, transform_fn, samples=None, init_x0=None, outer_max_iter=1000, tol=1e-6, verbose=1, eva_iter=10,
           metric='sc', **kwargs):
    r"""
    Reconstruct spectrogram phase using `Inversion of Auditory Spectrograms, Traditional Spectrograms, and Other
    Envelope Representations`_, where I directly use the :class:`torch.optim.LBFGS` optimizer provided in PyTorch.
    This method doesn't restrict to traditional short-time Fourier Transform, but any kinds of presentation (ex: Mel-scaled Spectrogram) as
    long as the transform function is differentiable.

    .. _`Inversion of Auditory Spectrograms, Traditional Spectrograms, and Other Envelope Representations`:
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=6949659

    Args:
        spec (Tensor): the input presentation.
        transform_fn: a function that has the form ``spec = transform_fn(x)`` where x is the time-domain tensor.
        samples (int or sequence of int, optional): size of the time-domain tensor to create when ``init_x0`` is
            not given; either the number of samples or a full shape. Default: :obj:`None`
        init_x0 (Tensor, optional): a tensor that is used as the initial time domain samples. If not
            provided, a tensor of size ``samples`` drawn from a normal distribution with a standard deviation
            of ``1e-6`` is used.
        outer_max_iter (int): maximum number of iterations before timing out.
        tol (float): tolerance of the stopping condition based on L2 loss. Default: ``1e-6``.
        verbose (bool): whether to be verbose. Default: ``1``
        eva_iter (int): steps size for evaluation. After each step, the function defined in ``metric`` will
            evaluate. Default: ``10``
        metric (str): evaluation function. Currently available functions: ``'sc'`` (spectral convergence),
            ``'snr'`` or ``'ser'``. Default: ``'sc'``
        **kwargs: other arguments that pass to :class:`torch.optim.LBFGS`.

    Returns:
        The optimized time-domain tensor, detached from the autograd graph.

    Raises:
        TypeError: If neither ``samples`` nor ``init_x0`` is provided.

    """
    if init_x0 is None:
        if samples is None:
            raise TypeError('either `samples` or `init_x0` must be provided')
        # `samples` is documented as an int, which cannot be star-unpacked;
        # normalise it to a shape tuple while still accepting a sequence.
        size = (samples,) if isinstance(samples, int) else tuple(samples)
        init_x0 = spec.new_empty(*size).normal_(std=1e-6)
    x = nn.Parameter(init_x0)
    T = spec

    criterion = nn.MSELoss()
    optimizer = LBFGS([x], **kwargs)

    def inner_closure():
        # Re-evaluated by LBFGS as often as its line search requires.
        optimizer.zero_grad()
        V = transform_fn(x)
        loss = criterion(V, T)
        loss.backward()
        return loss

    def outer_closure(status_dict):
        optimizer.step(inner_closure)
        # Evaluation pass only; no graph is needed here.
        with torch.no_grad():
            V = transform_fn(x)
        return V

    _training_loop(
        outer_closure,
        {},
        T,
        outer_max_iter,
        tol,
        verbose,
        eva_iter,
        metric
    )

    return x.detach()
def phase_init(spec, **stft_kwargs):
    r"""
    A phase initialize function that can be seen as a simplified version of `Single Pass Spectrogram Inversion`_.

    .. _`Single Pass Spectrogram Inversion`: https://ieeexplore.ieee.org/document/7251907

    Args:
        spec (Tensor): the input tensor of size :math:`(N \times T)` or :math:`(B \times N \times T)` (magnitude).
            Complex input is rejected.
        **stft_kwargs: other arguments that pass to :func:`torch.stft`

    Returns:
        The estimated complex valued spectrogram, with the same shape as ``spec``.

    """
    assert not spec.is_complex()
    shape = spec.shape
    if len(spec.shape) == 2:
        spec = spec.unsqueeze(0)
    assert len(spec.shape) == 3

    n_fft, processed_args = _args_helper(spec, **stft_kwargs)
    hop_length = processed_args['hop_length']

    phase = torch.zeros_like(spec)

    # Bins strictly larger than both neighbours along the frequency axis;
    # the first and the last bin can never qualify.
    mask = (spec[:, 1:-1] > spec[:, 2:]) & (spec[:, 1:-1] > spec[:, :-2])
    mask = F.pad(mask, [0, 0, 1, 1])

    # Magnitudes at each peak (b), one bin below (a) and one bin above (r).
    b = torch.masked_select(spec, mask)
    a = torch.masked_select(spec[:, :-1], mask[:, 1:])
    r = torch.masked_select(spec[:, 1:], mask[:, :-1])
    idx1, idx2, idx3 = torch.nonzero(mask).t().unbind()

    # Fractional bin offset of the peak from the three magnitudes, then the
    # phase advance per hop at that refined frequency.
    p = 0.5 * (a - r) / (a - 2 * b + r)
    omega = pi2 * (idx2.float() + p) / n_fft * hop_length

    # The peak bin and both of its neighbours share the same phase advance.
    phase[idx1, idx2, idx3] = omega
    phase[idx1, idx2 - 1, idx3] = omega
    phase[idx1, idx2 + 1, idx3] = omega

    # Accumulate the per-frame advance along the time axis.
    phase = torch.cumsum(phase, 2)
    angle = torch.exp(phase * 1j)

    spec = spec * angle
    return spec.view(shape)