Update param_init_fns.py
Browse files- param_init_fns.py +49 -51
param_init_fns.py
CHANGED
|
@@ -2,22 +2,26 @@ import math
|
|
| 2 |
import warnings
|
| 3 |
from collections.abc import Sequence
|
| 4 |
from functools import partial
|
| 5 |
-
from typing import Optional, Tuple, Union
|
| 6 |
import torch
|
| 7 |
from torch import nn
|
|
|
|
| 8 |
from .norm import NORM_CLASS_REGISTRY
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
def torch_default_param_init_fn_(module: nn.Module,
|
| 11 |
del kwargs
|
| 12 |
-
if
|
| 13 |
-
warnings.warn(f"Initializing network using module's reset_parameters attribute")
|
| 14 |
-
if hasattr(module, 'reset_parameters'):
|
| 15 |
module.reset_parameters()
|
| 16 |
|
| 17 |
-
def fused_init_helper_(module: nn.Module, init_fn_):
|
| 18 |
_fused = getattr(module, '_fused', None)
|
| 19 |
if _fused is None:
|
| 20 |
raise RuntimeError(f'Internal logic error')
|
|
|
|
| 21 |
(dim, splits) = _fused
|
| 22 |
splits = (0, *splits, module.weight.size(dim))
|
| 23 |
for (s, e) in zip(splits[:-1], splits[1:]):
|
|
@@ -25,10 +29,8 @@ def fused_init_helper_(module: nn.Module, init_fn_):
|
|
| 25 |
slice_indices[dim] = slice(s, e)
|
| 26 |
init_fn_(module.weight[slice_indices])
|
| 27 |
|
| 28 |
-
def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None,
|
| 29 |
del kwargs
|
| 30 |
-
if verbose > 1:
|
| 31 |
-
warnings.warn(f'If model has bias parameters they are initialized to 0.')
|
| 32 |
init_div_is_residual = init_div_is_residual
|
| 33 |
if init_div_is_residual is False:
|
| 34 |
div_is_residual = 1.0
|
|
@@ -36,20 +38,18 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
|
|
| 36 |
div_is_residual = math.sqrt(2 * n_layers)
|
| 37 |
elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
|
| 38 |
div_is_residual = init_div_is_residual
|
| 39 |
-
elif
|
| 40 |
div_is_residual = float(init_div_is_residual)
|
| 41 |
else:
|
| 42 |
div_is_residual = 1.0
|
| 43 |
raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
|
| 44 |
-
if
|
| 45 |
-
if verbose > 1:
|
| 46 |
-
warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
|
| 47 |
-
if isinstance(module, nn.Linear):
|
| 48 |
if hasattr(module, '_fused'):
|
| 49 |
fused_init_helper_(module, init_fn_)
|
| 50 |
else:
|
| 51 |
init_fn_(module.weight)
|
| 52 |
if module.bias is not None:
|
|
|
|
| 53 |
torch.nn.init.zeros_(module.bias)
|
| 54 |
if init_div_is_residual is not False and getattr(module, '_is_residual', False):
|
| 55 |
with torch.no_grad():
|
|
@@ -60,8 +60,6 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
|
|
| 60 |
if std == 0:
|
| 61 |
warnings.warn(f'Embedding layer initialized to 0.')
|
| 62 |
emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
|
| 63 |
-
if verbose > 1:
|
| 64 |
-
warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
|
| 65 |
elif emb_init_uniform_lim is not None:
|
| 66 |
lim = emb_init_uniform_lim
|
| 67 |
if isinstance(lim, Sequence):
|
|
@@ -75,17 +73,13 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
|
|
| 75 |
lim = [-lim, lim]
|
| 76 |
(a, b) = lim
|
| 77 |
emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
|
| 78 |
-
if verbose > 1:
|
| 79 |
-
warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
|
| 80 |
else:
|
| 81 |
emb_init_fn_ = init_fn_
|
| 82 |
emb_init_fn_(module.weight)
|
| 83 |
elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
|
| 84 |
-
if
|
| 85 |
-
warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
|
| 86 |
-
if hasattr(module, 'weight') and module.weight is not None:
|
| 87 |
torch.nn.init.ones_(module.weight)
|
| 88 |
-
if hasattr(module, 'bias') and module.bias
|
| 89 |
torch.nn.init.zeros_(module.bias)
|
| 90 |
elif isinstance(module, nn.MultiheadAttention):
|
| 91 |
if module._qkv_same_embed_dim:
|
|
@@ -114,32 +108,45 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
|
|
| 114 |
module.out_proj.weight.div_(div_is_residual)
|
| 115 |
if module.out_proj.bias is not None:
|
| 116 |
torch.nn.init.zeros_(module.out_proj.bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
else:
|
| 118 |
for _ in module.parameters(recurse=False):
|
| 119 |
raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
|
| 120 |
|
| 121 |
-
def _normal_init_(std, mean=0.0):
|
| 122 |
return partial(torch.nn.init.normal_, mean=mean, std=std)
|
| 123 |
|
| 124 |
-
def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None,
|
| 125 |
del kwargs
|
| 126 |
init_fn_ = _normal_init_(std=std)
|
| 127 |
-
|
| 128 |
-
warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
|
| 129 |
-
generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
| 130 |
|
| 131 |
-
def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None,
|
| 132 |
del kwargs
|
| 133 |
if init_std is None:
|
| 134 |
raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
|
| 135 |
-
_normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim
|
| 136 |
|
| 137 |
-
def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None,
|
| 138 |
del kwargs
|
| 139 |
std = math.sqrt(2 / (5 * d_model))
|
| 140 |
-
_normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim
|
| 141 |
|
| 142 |
-
def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None,
|
| 143 |
"""From section 2.3.1 of GPT-NeoX-20B:
|
| 144 |
|
| 145 |
An Open-Source Autoregressive Language Model — Black et al. (2022)
|
|
@@ -148,34 +155,25 @@ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init
|
|
| 148 |
"""
|
| 149 |
del kwargs
|
| 150 |
residual_div = n_layers / math.sqrt(10)
|
| 151 |
-
|
| 152 |
-
warnings.warn(f'setting init_div_is_residual to {residual_div}')
|
| 153 |
-
small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
| 154 |
|
| 155 |
-
def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu',
|
| 156 |
del kwargs
|
| 157 |
-
if verbose > 1:
|
| 158 |
-
warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
|
| 159 |
kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
|
| 160 |
-
generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim
|
| 161 |
|
| 162 |
-
def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu',
|
| 163 |
del kwargs
|
| 164 |
-
if verbose > 1:
|
| 165 |
-
warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
|
| 166 |
kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
|
| 167 |
-
generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim
|
| 168 |
|
| 169 |
-
def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0,
|
| 170 |
del kwargs
|
| 171 |
xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
|
| 172 |
-
|
| 173 |
-
warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
|
| 174 |
-
generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
| 175 |
|
| 176 |
-
def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0,
|
|
|
|
| 177 |
xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
|
| 178 |
-
|
| 179 |
-
warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
|
| 180 |
-
generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
| 181 |
MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
|
|
|
|
| 2 |
import warnings
|
| 3 |
from collections.abc import Sequence
|
| 4 |
from functools import partial
|
| 5 |
+
from typing import Any, Callable, Optional, Tuple, Union
|
| 6 |
import torch
|
| 7 |
from torch import nn
|
| 8 |
+
from .fc import FC_CLASS_REGISTRY
|
| 9 |
from .norm import NORM_CLASS_REGISTRY
|
| 10 |
+
# transformer_engine is an optional dependency. When it cannot be imported,
# ``te`` is bound to None so callers can guard with ``te is not None``.
try:
    import transformer_engine.pytorch as te
except Exception:  # unlike a bare ``except``, lets KeyboardInterrupt/SystemExit propagate
    te = None
|
| 14 |
|
| 15 |
+
def torch_default_param_init_fn_(module: nn.Module, **kwargs: Any) -> None:
    """Initialize ``module`` by calling its own ``reset_parameters`` method.

    Args:
        module: Module to initialize in place.
        **kwargs: Accepted for signature compatibility and discarded.

    A module without a callable ``reset_parameters`` attribute is left
    untouched.
    """
    del kwargs
    # Look the attribute up once; ``callable`` is the idiomatic runtime check
    # (rather than ``isinstance`` against ``typing.Callable``).
    reset_fn = getattr(module, 'reset_parameters', None)
    if callable(reset_fn):
        reset_fn()
|
| 19 |
|
| 20 |
+
def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
|
| 21 |
_fused = getattr(module, '_fused', None)
|
| 22 |
if _fused is None:
|
| 23 |
raise RuntimeError(f'Internal logic error')
|
| 24 |
+
assert isinstance(module.weight, torch.Tensor)
|
| 25 |
(dim, splits) = _fused
|
| 26 |
splits = (0, *splits, module.weight.size(dim))
|
| 27 |
for (s, e) in zip(splits[:-1], splits[1:]):
|
|
|
|
| 29 |
slice_indices[dim] = slice(s, e)
|
| 30 |
init_fn_(module.weight[slice_indices])
|
| 31 |
|
| 32 |
+
def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
|
| 33 |
del kwargs
|
|
|
|
|
|
|
| 34 |
init_div_is_residual = init_div_is_residual
|
| 35 |
if init_div_is_residual is False:
|
| 36 |
div_is_residual = 1.0
|
|
|
|
| 38 |
div_is_residual = math.sqrt(2 * n_layers)
|
| 39 |
elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
|
| 40 |
div_is_residual = init_div_is_residual
|
| 41 |
+
elif init_div_is_residual.isnumeric():
|
| 42 |
div_is_residual = float(init_div_is_residual)
|
| 43 |
else:
|
| 44 |
div_is_residual = 1.0
|
| 45 |
raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
|
| 46 |
+
if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
|
|
|
|
|
|
|
|
|
|
| 47 |
if hasattr(module, '_fused'):
|
| 48 |
fused_init_helper_(module, init_fn_)
|
| 49 |
else:
|
| 50 |
init_fn_(module.weight)
|
| 51 |
if module.bias is not None:
|
| 52 |
+
assert isinstance(module.bias, torch.Tensor)
|
| 53 |
torch.nn.init.zeros_(module.bias)
|
| 54 |
if init_div_is_residual is not False and getattr(module, '_is_residual', False):
|
| 55 |
with torch.no_grad():
|
|
|
|
| 60 |
if std == 0:
|
| 61 |
warnings.warn(f'Embedding layer initialized to 0.')
|
| 62 |
emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
|
|
|
|
|
|
|
| 63 |
elif emb_init_uniform_lim is not None:
|
| 64 |
lim = emb_init_uniform_lim
|
| 65 |
if isinstance(lim, Sequence):
|
|
|
|
| 73 |
lim = [-lim, lim]
|
| 74 |
(a, b) = lim
|
| 75 |
emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
|
|
|
|
|
|
|
| 76 |
else:
|
| 77 |
emb_init_fn_ = init_fn_
|
| 78 |
emb_init_fn_(module.weight)
|
| 79 |
elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
|
| 80 |
+
if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor):
|
|
|
|
|
|
|
| 81 |
torch.nn.init.ones_(module.weight)
|
| 82 |
+
if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor):
|
| 83 |
torch.nn.init.zeros_(module.bias)
|
| 84 |
elif isinstance(module, nn.MultiheadAttention):
|
| 85 |
if module._qkv_same_embed_dim:
|
|
|
|
| 108 |
module.out_proj.weight.div_(div_is_residual)
|
| 109 |
if module.out_proj.bias is not None:
|
| 110 |
torch.nn.init.zeros_(module.out_proj.bias)
|
| 111 |
+
elif te is not None and isinstance(module, te.LayerNormMLP):
|
| 112 |
+
if isinstance(module.layer_norm_weight, torch.Tensor):
|
| 113 |
+
torch.nn.init.ones_(module.layer_norm_weight)
|
| 114 |
+
if isinstance(module.layer_norm_bias, torch.Tensor):
|
| 115 |
+
torch.nn.init.zeros_(module.layer_norm_bias)
|
| 116 |
+
init_fn_(module.fc1_weight)
|
| 117 |
+
if module.fc1_bias is not None:
|
| 118 |
+
assert isinstance(module.fc1_bias, torch.Tensor)
|
| 119 |
+
torch.nn.init.zeros_(module.fc1_bias)
|
| 120 |
+
init_fn_(module.fc2_weight)
|
| 121 |
+
if module.fc2_bias is not None:
|
| 122 |
+
assert isinstance(module.fc2_bias, torch.Tensor)
|
| 123 |
+
torch.nn.init.zeros_(module.fc2_bias)
|
| 124 |
+
with torch.no_grad():
|
| 125 |
+
module.fc2_weight.div_(div_is_residual)
|
| 126 |
else:
|
| 127 |
for _ in module.parameters(recurse=False):
|
| 128 |
raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
|
| 129 |
|
| 130 |
+
def _normal_init_(std: float, mean: float=0.0) -> Callable:
|
| 131 |
return partial(torch.nn.init.normal_, mean=mean, std=std)
|
| 132 |
|
| 133 |
+
def _normal_param_init_fn_(
    module: nn.Module,
    std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    """Run ``generic_param_init_fn_`` on ``module`` with a normal initializer.

    The initializer is built by ``_normal_init_(std=std)``; every other named
    argument is forwarded unchanged. Extra ``kwargs`` are discarded.
    """
    del kwargs
    generic_param_init_fn_(
        module=module,
        init_fn_=_normal_init_(std=std),
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )
|
|
|
|
|
|
|
| 137 |
|
| 138 |
+
def baseline_param_init_fn_(
    module: nn.Module,
    init_std: Optional[float],
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    """Initialize ``module`` via ``_normal_param_init_fn_`` with ``std=init_std``.

    Extra ``kwargs`` are discarded.

    Raises:
        ValueError: If ``init_std`` is None.
    """
    del kwargs
    if init_std is not None:
        _normal_param_init_fn_(
            module=module,
            std=init_std,
            d_model=d_model,
            n_layers=n_layers,
            init_div_is_residual=init_div_is_residual,
            emb_init_std=emb_init_std,
            emb_init_uniform_lim=emb_init_uniform_lim,
        )
        return
    raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
|
| 143 |
|
| 144 |
+
def small_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    """Initialize ``module`` via ``_normal_param_init_fn_`` with a width-based std.

    The standard deviation is ``sqrt(2 / (5 * d_model))``. Extra ``kwargs``
    are discarded.
    """
    del kwargs
    small_init_std = math.sqrt(2 / (5 * d_model))
    _normal_param_init_fn_(
        module=module,
        std=small_init_std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )
|
| 148 |
|
| 149 |
+
def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
|
| 150 |
"""From section 2.3.1 of GPT-NeoX-20B:
|
| 151 |
|
| 152 |
An Open-Source Autoregressive Language Model — Black et al. (2022)
|
|
|
|
| 155 |
"""
|
| 156 |
del kwargs
|
| 157 |
residual_div = n_layers / math.sqrt(10)
|
| 158 |
+
small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
|
|
|
|
|
|
|
| 159 |
|
| 160 |
+
def kaiming_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = 'fan_in',
    init_nonlinearity: str = 'leaky_relu',
    **kwargs: Any,
) -> None:
    """Run ``generic_param_init_fn_`` with a Kaiming-uniform initializer.

    ``init_gain``, ``fan_mode`` and ``init_nonlinearity`` are bound to
    ``nn.init.kaiming_uniform_`` as ``a``, ``mode`` and ``nonlinearity``.
    Extra ``kwargs`` are discarded.
    """
    del kwargs
    kaiming_kwargs = {'a': init_gain, 'mode': fan_mode, 'nonlinearity': init_nonlinearity}
    generic_param_init_fn_(
        module=module,
        init_fn_=partial(nn.init.kaiming_uniform_, **kaiming_kwargs),
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )
|
| 164 |
|
| 165 |
+
def kaiming_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = 'fan_in',
    init_nonlinearity: str = 'leaky_relu',
    **kwargs: Any,
) -> None:
    """Run ``generic_param_init_fn_`` with a Kaiming-normal initializer.

    ``init_gain``, ``fan_mode`` and ``init_nonlinearity`` are bound to
    ``torch.nn.init.kaiming_normal_`` as ``a``, ``mode`` and ``nonlinearity``.
    Extra ``kwargs`` are discarded.
    """
    del kwargs
    kaiming_kwargs = {'a': init_gain, 'mode': fan_mode, 'nonlinearity': init_nonlinearity}
    generic_param_init_fn_(
        module=module,
        init_fn_=partial(torch.nn.init.kaiming_normal_, **kaiming_kwargs),
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )
|
| 169 |
|
| 170 |
+
def xavier_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    **kwargs: Any,
) -> None:
    """Run ``generic_param_init_fn_`` with a Xavier-uniform initializer.

    ``init_gain`` is bound to ``torch.nn.init.xavier_uniform_`` as ``gain``.
    Extra ``kwargs`` are discarded.
    """
    del kwargs
    generic_param_init_fn_(
        module=module,
        init_fn_=partial(torch.nn.init.xavier_uniform_, gain=init_gain),
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )
|
|
|
|
|
|
|
| 174 |
|
| 175 |
+
def xavier_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    **kwargs: Any,
) -> None:
    """Run ``generic_param_init_fn_`` with a Xavier-normal initializer.

    ``init_gain`` is bound to ``torch.nn.init.xavier_normal_`` as ``gain``.
    Extra ``kwargs`` are discarded.
    """
    del kwargs
    generic_param_init_fn_(
        module=module,
        init_fn_=partial(torch.nn.init.xavier_normal_, gain=init_gain),
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )
|
|
|
|
|
|
|
| 179 |
# Maps each init-scheme name to the function that implements it.
MODEL_INIT_REGISTRY = {
    'default_': torch_default_param_init_fn_,
    'baseline_': baseline_param_init_fn_,
    'kaiming_uniform_': kaiming_uniform_param_init_fn_,
    'kaiming_normal_': kaiming_normal_param_init_fn_,
    'neox_init_': neox_param_init_fn_,
    'small_init_': small_param_init_fn_,
    'xavier_uniform_': xavier_uniform_param_init_fn_,
    'xavier_normal_': xavier_normal_param_init_fn_,
}
|