Conversation
Commits:
* add asserts and fix post training readme; precommit (Co-authored-by: Quentin Anthony <qganthony@yahoo.com>)
* fix typo; fix neoxargs usage test; skip conversion test due to multiprocessing issue; precommit (Co-authored-by: Quentin Anthony <qganthony@yahoo.com>)
* Add ERROR logging prefix and sort alphabetically; fix comment
```
- **dim_att**: int
```
We should either have unified args (across mamba, RWKV, and transformer blocks) for these, or prefix these args with whatever block type they're targeting (e.g. `rwkv_dim_att`).
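To make the prefixed option concrete, here is a minimal sketch of what block-scoped args could look like; the field names and fallback semantics are illustrative assumptions, not the PR's actual API:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class NeoXArgsRWKV:
    """Sketch of block-type-prefixed args (hypothetical names)."""

    # Width of the time-mix ("attention") path; None falls back to hidden_size.
    rwkv_dim_att: Optional[int] = None

    # Width of the channel-mix path; None means derive it from hidden_size.
    rwkv_dim_ffn: Optional[int] = None

    # Head size for the time-mix path, replacing the overloaded
    # num_attention_heads from the transformer namespace.
    rwkv_head_size: Optional[int] = None
```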
| "num_layers": 24, | ||
| "hidden_size": 1024, | ||
| "num_attention_heads": 16, # head_size = dim_att / num_attention_heads. |
Similar comment here. Calling these attention heads is highly misleading.
I kind of disagree: RWKV code generally refers to time mixing as attention, and the RWKV kernel is often described as a type of "linear attention." I can add a set of configs to decouple RWKV and transformer config options, but in my opinion that would just create a lot of config args with essentially the same purpose.
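For readers skimming the thread, the arithmetic encoded in the config comment above works out as follows; the values come from the snippet under review, and `dim_att` defaulting to `hidden_size` is an assumption based on common RWKV setups:

```python
hidden_size = 1024
dim_att = hidden_size  # assumed default: dim_att falls back to hidden_size
num_attention_heads = 16

head_size = dim_att // num_attention_heads
assert head_size == 64
```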
```python
except ModuleNotFoundError:
    print(
        "Unable to import RWKV FLA kernels. Install them from our requirements/requirements-rwkv.txt, \
or directly from https://github.com/sustcsonglin/flash-linear-attention.git, or use CUDA kernels."
```
This last point, "or use CUDA kernels", is confusing. Can you add a "by doing xyz" so that users know what you mean?
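One possible rewording, as a sketch only; the `rwkv_fla` flag appears elsewhere in this diff, but the exact fallback wording and the `fla` import name are assumptions:

```python
try:
    import fla  # flash-linear-attention package (import name assumed)
except ModuleNotFoundError:
    print(
        "Unable to import RWKV FLA kernels. Install them from our "
        "requirements/requirements-rwkv.txt, or directly from "
        "https://github.com/sustcsonglin/flash-linear-attention.git. "
        "Alternatively, fall back to the CUDA kernels by setting "
        "rwkv_fla: false in your config."
    )
    raise
```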
```diff
@@ -104,7 +126,7 @@ class RWKV_TimeMix(nn.Module):
     TODO: fix jit compiling.
```
Is this based on the parser issue we discussed? Before merging with this TODO in place, I think it's worth testing plain JIT, reordered JIT, and the heuristics I suggested earlier.
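For that testing, one low-effort way to flip JIT on and off between runs is an env-var gate; this sketch mirrors the `RWKV_JIT_ON` pattern from the upstream RWKV reference code, not necessarily anything in this repo:

```python
import os

import torch
import torch.nn as nn

# Toggle JIT per run, e.g.: RWKV_JIT_ON=1 python train.py ...
if int(os.environ.get("RWKV_JIT_ON", "0")):
    JITModule = torch.jit.ScriptModule
    JITMethod = torch.jit.script_method
else:
    JITModule = nn.Module

    def JITMethod(fn):  # no-op decorator when JIT is off
        return fn
```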
```diff
-        self.ffn = RWKV_ChannelMix(neox_args, layer_number)
+        self.ffn = ParallelRWKV_ChannelMix(neox_args, layer_number, init_method=init_method)
```
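For context on what the `Parallel` variant presumably does: sharding the channel-mix projections across tensor-parallel ranks in the Megatron style. A rough sketch under that assumption; GPT-NeoX's `mpu` parallel linear layers do exist, but the constructor args below are from memory, `dim_ffn` is an assumed arg name, and the class body is heavily simplified:

```python
import torch.nn as nn

from megatron import mpu  # GPT-NeoX tensor-parallel utilities


class ParallelChannelMixSketch(nn.Module):
    """Sketch: column-parallel up-projection, row-parallel down-projection,
    so the element-wise activation stays local to each TP rank."""

    def __init__(self, neox_args, init_method):
        super().__init__()
        self.key = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=neox_args.dim_ffn,  # arg name assumed
            gather_output=False,  # keep the output sharded for the local activation
            init_method=init_method,
        )
        self.value = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=neox_args.dim_ffn,
            output_size=neox_args.hidden_size,
            input_is_parallel=True,  # consumes the sharded activation directly
            init_method=init_method,
        )
```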
```python
if neox_args.attention_dropout > 0:
```
Another attention arg used for RWKV. Can we decouple attention dropout from RWKV?
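A minimal sketch of that decoupling, assuming a hypothetical `rwkv_dropout` arg that falls back to `attention_dropout` when unset:

```python
import torch.nn as nn


def build_rwkv_dropout(neox_args):
    """Return a Dropout module for RWKV, or None when dropout is disabled.

    Prefers the (hypothetical) rwkv_dropout arg; falls back to
    attention_dropout for backward compatibility.
    """
    p = getattr(neox_args, "rwkv_dropout", None)
    if p is None:
        p = neox_args.attention_dropout
    return nn.Dropout(p=p) if p > 0 else None
```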
```python
WARNING = f"{YELLOW}[WARNING]{END}"

### Formatted logging prefixes ###
ERROR = f"{RED}[ERROR]{END} "
```
I don't think we've properly merged this branch onto upstream main, since this is tracking as a change. Please do this.
```python
}
"""

rwkv_fla: bool = False
```
Regenerate neox_arguments.md, since this arg isn't showing up there.
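For context, a config fragment exercising the new flag might look like the following Python-style sketch (GPT-NeoX config files are YAML); only `rwkv_fla` itself is from this diff, the other keys mirror the config snippet quoted earlier in the review:

```python
{
    "num_layers": 24,
    "hidden_size": 1024,
    # Use the Triton flash-linear-attention kernels instead of the CUDA ones.
    "rwkv_fla": True,
}
```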
```python
    datatype=datatype,
)
elif neox_args.train_impl == "kto":
    assert (
```
I think these will also go away with a proper rebase onto latest main.
* fix 'intermediate_size' in Llama configuration files after the 'mlp_type' option was removed (#1309); config adjustments for llama and gated activations; pre-commit (Co-authored-by: jahatef <hatef.4@buckeyemail.osu.edu>, Quentin Anthony <qganthony@yahoo.com>)
* Python 3.10 support (added in #1122); update wording on torch and python (Co-authored-by: Quentin Anthony <qganthony@yahoo.com>)
* adds pyproject files and tests; formatting and add dev packages to dev req files; improve req testing (Co-authored-by: Quentin Anthony <qganthony@yahoo.com>)
Adds a tensor parallel implementation for RWKV, and support for the Triton FLA (flash-linear-attention) implementation, in GPT-NeoX.