Tf 214 #220

Status: Open. Wants to merge 140 commits into base: master.

Commits (140)
f62009a
move encoder layers from the policy head to the end of the net body
Arcturai Feb 21, 2022
9acf63f
add in attention policy map
Arcturai Feb 21, 2022
553f9ea
sync and add support for chess transformer
Arcturai Mar 13, 2022
b9f13f6
restore tf.function()'s
Arcturai Mar 13, 2022
6a3e292
bugfixes and DeepNorm implementation {https://arxiv.org/abs/2203.00555}
Arcturai Mar 14, 2022
9dd755a
asdf
Ergodice Mar 19, 2022
5b37de0
try fix for net.py bug
Arcturai Mar 28, 2022
523fa60
correct rule_50 scaling in net.py
Arcturai Mar 30, 2022
42292ab
'policy map' positional encoding
Arcturai Apr 2, 2022
cd1d5d6
Potpourri of architectural improvements
Ergodice Apr 3, 2022
9652e93
Added dense
Ergodice Apr 4, 2022
eec873b
Small changes
Ergodice Apr 8, 2022
17bb790
Sideways and Davit attention, yaml spec improvements
Ergodice Apr 13, 2022
bb26b31
Weight gen, buckets
Ergodice Apr 23, 2022
65d6ecc
Merge pull request #1 from Ergodice/multiple-nets
Ergodice Apr 23, 2022
5baa723
dytalking_heads, fix visible device issue
Ergodice Apr 24, 2022
afd1ad0
Fix DyDense saving issues
Ergodice Apr 26, 2022
b98b1c4
Add fullgen
Ergodice May 8, 2022
99b5b20
Horizontal and vertical convolutions
Ergodice May 12, 2022
636f502
Fix fullgen history
Ergodice May 12, 2022
4462fda
Add dytalking heads
Ergodice Aug 17, 2022
d16db6c
Remove legacy code, add arc encoding and example yaml
Ergodice Aug 22, 2022
50ae5c0
typo in example.yaml
Ergodice Aug 22, 2022
a8e76d8
Update Readme, remove old files
Ergodice Aug 22, 2022
a01b2fa
Typos, bug fixes
Ergodice Aug 25, 2022
1a5a860
Fix checkpointing, remove obsolete stuff
Ergodice Sep 1, 2022
97bc128
Remove dyrelu reference
Ergodice Sep 1, 2022
0e7740e
Remove use_simple_gating, fix activation
Ergodice Sep 1, 2022
6619318
Add metrics, update README
Ergodice Sep 1, 2022
e8c76b2
Smolgen!
Ergodice Nov 22, 2022
886e1ed
Update example.yaml
Ergodice Nov 22, 2022
e896316
Fix typo
Ergodice Nov 22, 2022
8053a38
Merge branch 'attention-net-body' of https://github.com/Ergodice/lcze…
almaudoh Dec 12, 2022
65d7a45
Add smolgen weights to protobuf converters
almaudoh Dec 13, 2022
c200fd2
Add input gating weights
almaudoh Dec 13, 2022
baf54b3
input gating rename.
almaudoh Dec 13, 2022
a096e91
Merge branch 'master' into attention-net-body-smolgen
almaudoh Dec 13, 2022
c0a38e1
Minor fixes
almaudoh Dec 13, 2022
8bb3b36
Remove duplicate class declaration
almaudoh Dec 13, 2022
3ed0360
Merge pull request #2 from almaudoh/attention-net-body-smolgen
Ergodice Dec 14, 2022
e30ce6a
Nadam, RMSprop, better multigpu, onnx continuation
teck45 Dec 25, 2022
e58bd38
tfprocess.py typo fix
teck45 Dec 25, 2022
70b6451
Merge pull request #3 from teck45/patch-1
Ergodice Dec 26, 2022
ffd1a67
Add reducible policy loss
Ergodice Dec 26, 2022
1d402a9
Remove talking heads and glu
Ergodice Dec 26, 2022
d006889
Add fast depthwise process
Ergodice Dec 26, 2022
cdc76ca
Add attn_wts to network outputs
Ergodice Dec 26, 2022
580460d
Add optimizers to example.yaml
Ergodice Dec 26, 2022
49f54c8
Remove comment
Ergodice Dec 26, 2022
bb06d69
Update README.md
Ergodice Dec 26, 2022
7922c43
Add BT3 improvements
Ergodice Jan 10, 2023
2ac48cf
Remove obsolete modules
Ergodice Jan 15, 2023
b4e7203
Add sparsity support
Ergodice Jan 15, 2023
1544b97
Prevent sparsity from turning on automatically
Ergodice Jan 16, 2023
0c10d4e
Update README.md
Ergodice Jan 16, 2023
5bdc121
Revert to ln and pol loss -> reducible pol loss
Ergodice Jan 18, 2023
28189aa
Revert to default activation and layernorm
Ergodice Jan 18, 2023
3d9f947
Refactor ffn activation
Ergodice Jan 31, 2023
884b994
Switch single to double quotes following PEP 8
Ergodice Jan 31, 2023
cbfe579
Replace single quotes with double
Ergodice Jan 31, 2023
e57af20
Single to double quotes in train.py
Ergodice Jan 31, 2023
b881e66
Add categorical value
Ergodice Feb 25, 2023
40df352
Turn off arc_encoding in example.yaml
Ergodice Feb 25, 2023
ddd18b4
Remove sqrrelu from example.yaml
Ergodice Feb 25, 2023
b1febf4
Refactor value cat loss
Ergodice Feb 25, 2023
7bf0b07
Fix make_value_buckets bug
Ergodice Feb 28, 2023
6d510e7
Clarify comment
Ergodice Feb 28, 2023
ebed2c6
Clarify reducible_policy_loss
Ergodice Feb 28, 2023
49251d8
fix default activation bug
Ergodice Mar 2, 2023
cf0bd56
Set ffn activation lower down
Ergodice Mar 2, 2023
813c3ca
Describe categorical value
Ergodice Mar 2, 2023
b1460bf
Clarify cvh in Readme
Ergodice Mar 2, 2023
3d05a3a
Fix math error in Readme
Ergodice Mar 2, 2023
ecd1b44
Add roadmap
Ergodice Apr 6, 2023
9dcc7d9
Add BT3 changes
Ergodice Apr 8, 2023
adc2b25
switch kernel initializer to fan_avg
Ergodice Apr 8, 2023
55f65d1
More BT3 improvements
Ergodice May 4, 2023
a69a69e
Refactor to be more similar to official code
Ergodice May 5, 2023
38582e7
Add new net protobuf
Ergodice May 5, 2023
6f88d80
Couple missed net arguments
Ergodice May 5, 2023
0104192
Fix multigpu training
Ergodice May 5, 2023
45e803d
Revert "Fix multigpu training"
Ergodice May 5, 2023
f741d62
Enforce ffn and add seed to xavier init
Ergodice May 5, 2023
4b9e143
Add temporary proto hack to deal with versions
Ergodice May 5, 2023
e300a2c
"norm" -> "ln" in net.py
Ergodice May 6, 2023
e1258d6
Fix multigpu training
Ergodice May 6, 2023
3e9e6ec
Fix proto enum reuse
Ergodice May 6, 2023
a545a11
Support more optimizers
Ergodice May 7, 2023
5d6ad77
Fix half-precision training
Ergodice May 7, 2023
1a38976
Switch val err to have highest prto enum values
Ergodice May 8, 2023
9e8507e
Initial quantization commit
Ergodice May 14, 2023
24aca65
Support mish activation
Ergodice May 14, 2023
0712e93
Update net.proto with quant scales
Ergodice May 15, 2023
0578e33
Revert to pre quant
Ergodice May 20, 2023
440982a
Add quantization file
Ergodice May 20, 2023
f0972ae
More efficient round implementation
Ergodice May 20, 2023
f2e8bf2
Fix in_units in QuantizedDense
Ergodice May 20, 2023
e5b2017
Fix bias shape and add bias and activation support
Ergodice May 20, 2023
0c2082e
Disable emb ln (not ready in previous proto version)
Ergodice May 22, 2023
90736b8
Efficient grad scale
Ergodice May 22, 2023
51fb562
Split init_net into functions
Ergodice May 22, 2023
93bdbcb
Remove support for conv weights
Ergodice May 22, 2023
a790eda
Remove unused variable
Ergodice May 22, 2023
6470921
Remove quantization
Ergodice Jul 4, 2023
b05d625
Update BT3 proto
Ergodice Jul 14, 2023
9c664d9
Update with BT3 improvements
Ergodice Jul 14, 2023
2c34955
Refactor and bugfixes
Ergodice Aug 4, 2023
dc62f82
Make BT3 features optional
Ergodice Aug 4, 2023
d5d8aed
Add moe, assign default None to output
Ergodice Aug 4, 2023
dce28f2
Refactor losses
Ergodice Aug 8, 2023
e26692b
Add and clean up heads
Ergodice Aug 18, 2023
7ec4b47
Move sqrt in optwgt gen
Ergodice Aug 19, 2023
82bd638
Fix bug with optwgt shape
Ergodice Aug 22, 2023
055e903
Remove policy_val
Ergodice Aug 23, 2023
bb25b71
Add opponent policy
Ergodice Aug 24, 2023
14ed59c
Revert "Add opponent policy"
Ergodice Aug 27, 2023
b0626de
Fix q_st rescoring
Ergodice Aug 27, 2023
6b55227
Simplify apply_alpha
Ergodice Aug 27, 2023
c51152c
Change sign on alpha exponent
Ergodice Aug 28, 2023
4905804
Revert "Change sign on alpha exponent"
Ergodice Aug 28, 2023
8220533
Remove policy val loss
Ergodice Aug 30, 2023
52231ad
Fix net.py to match current proto.
almaudoh Sep 6, 2023
64b9ba5
Fix policy embedding reference for new policy heads.
almaudoh Sep 7, 2023
39ce9aa
Update policy in net.py to match proto
Ergodice Sep 17, 2023
00166e9
Fix typo in tfprocess.py
Ergodice Sep 17, 2023
d6f59b9
Remove arc encoding
Ergodice Sep 24, 2023
cb590eb
Set new network format for multihead nets.
almaudoh Oct 3, 2023
e1a71d9
Update tfprocess.py and net.py to save new embedding and version to p…
almaudoh Oct 6, 2023
1eacb21
Merge pull request #5 from almaudoh/attention-net-updates
Ergodice Oct 11, 2023
8d339ae
Get mixed precision working
Ergodice Oct 11, 2023
7ccd535
Add categorical value
Ergodice Oct 11, 2023
3a46f31
Add BT4 improvements
Ergodice Oct 11, 2023
80d46a7
Update lczero-common
Ergodice Oct 11, 2023
a0a473f
Fix policy loss bug
Ergodice Oct 14, 2023
cb2cf30
Reemove redundant code in TFProcess init
Ergodice Oct 20, 2023
dc83a58
Initial tf 2.14 commit
Ergodice Oct 31, 2023
8acd286
Allow disabling other biases
Ergodice Nov 1, 2023
95c2c1a
Support tf2.10 and fix future heads
Ergodice Nov 5, 2023
1317cc3
Turn on test reporting
Ergodice Nov 5, 2023
dbfb7fd
Update net.proto
Ergodice Nov 8, 2023
5 changes: 5 additions & 0 deletions .gitignore

```diff
@@ -106,3 +106,8 @@ venv.bak/
 
 # protobuf stuff
 tf/proto/
+
+# runs
+*.v2
+leelalogs/
+networks/
```
98 changes: 60 additions & 38 deletions README.md
@@ -1,3 +1,63 @@
# WARNING: THIS BRANCH REQUIRES TENSORFLOW 2.13+ AND HAS ONLY BEEN TESTED WITH 2.14

# What's new here?

I've added a few features and modules to the training pipeline. Please direct any questions to the Leela Discord server.


# BT3 Improvements
Building on the progress we made with ffn sizes, optimizers, and smolgen, BT3 brings several new improvements. The first is an improved embedding, which adds a dense layer from a full representation of the board state (64 * 12 values) to a list of 64 embedding vectors, followed by an ffn. Arcturai's positional encoding has been removed, since it seems to slightly degrade the performance of smolgen-equipped transformers. Previously, most of the attention heads in the first attention layer did not encode any useful information; now they have some information to work on. The policy accuracy gain is a modest 0.3%, and it comes at a negligible latency increase.
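As a rough illustration, the embedding might look like the following minimal sketch; the layer names, activation choices, and `d_model = 384` are assumptions for illustration, not the repo's actual code.

```python
import tensorflow as tf

d_model = 384  # embedding size (illustrative)

# dense layer from the full board representation (64 * 12 values) to
# 64 embedding vectors, followed by a per-square ffn
board_to_emb = tf.keras.layers.Dense(64 * d_model)
ffn_up = tf.keras.layers.Dense(4 * d_model, activation="relu")
ffn_down = tf.keras.layers.Dense(d_model)

def embed(board):
    """board: (batch, 64, 12) piece planes -> (batch, 64, d_model) embeddings."""
    flat = tf.reshape(board, [-1, 64 * 12])            # the full board state at once
    x = tf.reshape(board_to_emb(flat), [-1, 64, d_model])
    return x + ffn_down(ffn_up(x))                     # residual ffn on each square
```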

Support for layer normalization without centering and without biases has been added as well. This results in a negligible accuracy drop and should increase throughput by a couple of percent, especially for smaller models.
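For clarity, a hand-rolled sketch of normalization without centering or biases, which amounts to an RMS-style rescale; the epsilon and naming are assumptions:

```python
import tensorflow as tf

def rms_norm(x, gamma, eps=1e-6):
    # rescale by the root mean square only: no mean subtraction, no learned bias
    ms = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
    return x * tf.math.rsqrt(ms + eps) * gamma
```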

There are also a significant number of auxiliary heads, inspired by work on the KataGo Go engine. One of the new ones is a "policy-value" head, which for each move predicts the value of that move, weighted by the proportion of playouts that node received. It can be thought of as predicting the value given that the move was found to be best, and it may allow us to find checkmates and decisive lines faster.


# Quality of life
There are three quality of life improvements: a progress bar, new metrics, and pure attention code.

Progress bar: A simple progress bar implemented with the Python `rich` module displays the current step (including part-steps if the batches are split) and the expected time to completion.
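A minimal sketch of such a bar using `rich`; the pipeline's actual wiring, step counting, and labels will differ:

```python
from rich.progress import Progress

with Progress() as progress:
    task = progress.add_task("training", total=140000)  # total_steps
    for step in range(140000):
        # ... run one training step (or part-step) here ...
        progress.advance(task)  # rich shows completed steps and time remaining
```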

Pure attention: The pipeline no longer contains any code from the original ResNet architecture. This makes for clearer yamls and code. The protobuf has been updated to support smolgen, input gating, and the squared ReLU activation function.

## More metrics

I've added train value accuracy and train policy accuracy for the sake of completeness and to help detect overfitting; the speed difference is negligible. There are also three new loss metrics for evaluating policy. The cross entropy we currently use is probably still the best loss for training, though we could instead turn the task into a classification problem, effectively using one-hot vectors at the targets' best moves; this would run the risk of overfitting.

Thresholded policy accuracies: the thresholded policy accuracy @ x% is the percentage of moves for which the net puts policy of at least x% on the move the target thinks is best.

Reducible policy loss is the portion of the policy loss we can actually reduce, i.e., the policy loss minus the entropy of the policy target.

The search policy loss is designed to loosely describe how long it would take to find the best move in the average position. It is implemented as the average of the multiplicative inverses of the network's policies at the targets' top moves, i.e., one over the harmonic mean of those values. This is not very accurate, since the search algorithm will often give up on moves the network does not like unless they provide returns the network can immediately recognize.
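Minimal sketches of the three metrics, assuming `target` and `policy` are `(batch, 1858)` probability distributions with illegal moves already masked out; the function names and epsilons are illustrative:

```python
import tensorflow as tf

def thresholded_policy_accuracy(target, policy, x=0.25):
    # fraction of positions where the net puts at least x (here 25%)
    # probability on the move the target thinks is best
    best = tf.argmax(target, axis=1)
    p_at_best = tf.gather(policy, best, batch_dims=1)
    return tf.reduce_mean(tf.cast(p_at_best >= x, tf.float32))

def reducible_policy_loss(target, policy):
    # cross entropy minus the entropy of the target: the part of the
    # policy loss that training can actually drive toward zero
    cross_entropy = -tf.reduce_sum(target * tf.math.log(policy + 1e-12), axis=1)
    entropy = -tf.reduce_sum(target * tf.math.log(target + 1e-12), axis=1)
    return tf.reduce_mean(cross_entropy - entropy)

def search_policy_loss(target, policy):
    # mean of the inverses of the policy at the target's top move,
    # i.e. one over the harmonic mean of those values
    best = tf.argmax(target, axis=1)
    p_at_best = tf.gather(policy, best, batch_dims=1)
    return tf.reduce_mean(1.0 / (p_at_best + 1e-12))
```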


# Architectural improvements
There are a few architectural improvements I've introduced; I list only the apparently useful ones here. For reference, doubling the model size (i.e., 40% larger embeddings or 100% more layers) seems to add 1.5% policy accuracy at 8h/384/512dff.

Note that the large expansion ratios of 4x in the models reported here are not as useful in larger models: a 1536dff outperforms a 4096dff at 10x 8h/1024.

The main improvement is smolgen, which adds 2% policy accuracy to a 10x 8h/384/512dff model.

I've also allowed the model to train with sparsity so that we can increase throughput on the Ada and Hopper generations of Nvidia GPUs.


## Smolgen

Smolgen is the best improvement by far. It adds around 2% policy accuracy to a 10x 8h/384/512dff model. The motivation is simple: how can we encode global information into self-attention? The encoder architecture has two stages: self-attention and a dense feedforward layer. Self-attention only picks out information shared between pairs of squares, while the feedforward layer looks at only one square. The only way global information enters the picture is through the softmax, and that cannot be expected to squeeze out any significant information.

Of course, repeated application of self-attention is sufficient with large enough embedding sizes and enough layers, but chess is fundamentally different from image recognition and NLP. The encoder architecture effectively partitions the input into nodes and, at each layer, allows them to spread information and then do some postprocessing on the results. This works in image recognition, since it makes sense to compare image patches and the image is represented well by patch embeddings. In NLP, lexical tokens are well suited to this spread of information, since the simple structures of grammar allow self-attention to work (with distance embeddings, of course, so that tokens can interact locally at first).

Compared to these problems, chess is a nightmare. What it means for two rooks to be on the same file depends greatly on whether there are pieces between them. Even though the transformer architecture provides large gains over ResNets, which are stuck processing local information, it is still not suited to a problem that requires processing not at the between-square level but at the global level.

Arcturai observed that adding embeddings representing squares reachable through knight, bishop, or rook moves vastly improved the architecture. My first attempt at improving upon this was logit gating: because the board structure is fixed, it makes sense to add an additive offset to the attention logits so that heads can better focus on what is important. A head focusing on diagonals could have its gating emphasize square pairs that lie on the same diagonal. I achieved further improvements by applying multiplicative factors to the attention weights, as in the sketch below.
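A minimal sketch of gating, with assumed variable names: a learned additive offset on the attention logits per head and square pair, plus a learned multiplicative factor on the attention weights.

```python
import tensorflow as tf

h = 8  # attention heads (illustrative)
add_gate = tf.Variable(tf.zeros([1, h, 64, 64]))  # additive offsets on the logits
mult_gate = tf.Variable(tf.ones([1, h, 64, 64]))  # multiplicative factors on the weights

def gated_attention_weights(logits):
    """logits: (batch, h, 64, 64) raw attention logits."""
    weights = tf.nn.softmax(logits + add_gate)  # static, board-structure-aware bias
    # note: after the multiplicative gate, rows no longer sum to one; whether
    # the real implementation renormalizes is not specified here
    return weights * mult_gate
```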

This solution works well but still has its shortcomings. In particular, it is static: we'd like the offsets to change with the input. If pawns are blocking a diagonal, we would like to greatly reduce the information transfer between pieces on that diagonal. This leads to fullgen, which dynamically generates additional attention logits from the board state. Because it should focus on spatial information, and a 13-way one-hot vector per square is sufficient to completely describe the board state, the representation is generated by applying a dense layer that compresses each square into a representation of size 32 (of course, the embeddings already contain processed information which is useful in this computation).

This is then flattened and put through two dense layers, with hidden sizes 256 and h×256, using swish nonlinearities (barely better than ReLU), where h is the number of heads. Finally, a dense layer (256 → 4096) maps each head's 256-vector to a 64×64 block of logits, yielding h×64×64 in total. These are added to the regularly computed attention logits. This last dense layer is extremely parameter-intensive, so it is shared across all layers; this works well in practice.

Smolgen adds about a percent of policy accuracy for +10% latency, which is well worth the cost. Increasing the number of heads so that each has size 16 adds ~0.5% policy accuracy to a 10x 8h/384/512dff, since smolgen can do a lot of the heavy lifting, but it is not clear whether this is worth the latency cost.
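Putting the pieces together, a minimal smolgen sketch; the sizes follow the description above, while the layer names, bias choices, and `d_model` are assumptions:

```python
import tensorflow as tf

h = 8          # attention heads (illustrative)
d_model = 384  # embedding size (illustrative)

compress = tf.keras.layers.Dense(32, use_bias=False)          # per-square: d_model -> 32
hidden1 = tf.keras.layers.Dense(256, activation="swish")      # 64*32 -> 256
hidden2 = tf.keras.layers.Dense(h * 256, activation="swish")  # 256 -> h*256
# the expensive final projection (256 -> 64*64) is shared across all layers
shared_out = tf.keras.layers.Dense(64 * 64, use_bias=False)

def smolgen_logits(x):
    """x: (batch, 64, d_model) square embeddings -> (batch, h, 64, 64) extra logits."""
    s = compress(x)                        # (batch, 64, 32)
    s = tf.reshape(s, [-1, 64 * 32])       # flatten the board
    s = hidden2(hidden1(s))                # (batch, h*256)
    s = tf.reshape(s, [-1, h, 256])        # one 256-vector per head
    return tf.reshape(shared_out(s), [-1, h, 64, 64])

# added to the ordinary scaled dot-product logits before the softmax:
# logits = q @ k^T / sqrt(depth) + smolgen_logits(x)
```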


# Training

The training pipeline resides in `tf`; it requires TensorFlow running on Linux (Ubuntu 16.04 in this case). (It can be made to work on Windows too, but it takes more effort.)
@@ -19,44 +79,6 @@ tar -xzf training-run1--20200711-2017.tar

Now that the data is in the right format, one can configure a training pipeline. This configuration is achieved through a yaml file; see `training/tf/configs/example.yaml`:

```yaml
%YAML 1.2
---
name: 'kb1-64x6'                       # ideally no spaces
gpu: 0                                 # gpu id to process on

dataset:
  num_chunks: 100000                   # newest nof chunks to parse
  train_ratio: 0.90                    # trainingset ratio
  # For separated test and train data.
  input_train: '/path/to/chunks/*/draw/' # supports glob
  input_test: '/path/to/chunks/*/draw/'  # supports glob
  # For a one-shot run with all data in one directory.
  # input: '/path/to/chunks/*/draw/'

training:
  batch_size: 2048                     # training batch
  total_steps: 140000                  # terminate after these steps
  test_steps: 2000                     # eval test set values after this many steps
  # checkpoint_steps: 10000            # optional frequency for checkpointing before finish
  shuffle_size: 524288                 # size of the shuffle buffer
  lr_values:                           # list of learning rates
    - 0.02
    - 0.002
    - 0.0005
  lr_boundaries:                       # list of boundaries
    - 100000
    - 130000
  policy_loss_weight: 1.0              # weight of policy loss
  value_loss_weight: 1.0               # weight of value loss
  path: '/path/to/store/networks'      # network storage dir

model:
  filters: 64
  residual_blocks: 6
...
```

The configuration is pretty self-explanatory; if you're new to training, I suggest looking at the [machine learning glossary](https://developers.google.com/machine-learning/glossary/) by Google. Now you can invoke training with the following command:

```bash
./train.py --cfg configs/example.yaml --output /tmp/mymodel.txt
```
6 changes: 1 addition & 5 deletions init.sh

```diff
@@ -1,5 +1 @@
-#!/usr/bin/env bash
-
-protoc --proto_path=libs/lczero-common --python_out=tf libs/lczero-common/proto/net.proto
-protoc --proto_path=libs/lczero-common --python_out=tf libs/lczero-common/proto/chunk.proto
-touch tf/proto/__init__.py
+protoc --proto_path=libs/lczero-common --python_out=tf libs/lczero-common/proto/net.proto; protoc --proto_path=libs/lczero-common --python_out=tf libs/lczero-common/proto/chunk.proto; touch tf/proto/__init__.py
```
2 changes: 1 addition & 1 deletion libs/lczero-common
Submodule lczero-common updated 1 file
+131 −0 proto/net.proto
34 changes: 31 additions & 3 deletions tf/attention_policy_map.py

```diff
@@ -1,4 +1,5 @@
 import numpy as np
+import tensorflow as tf
 
 
 move = np.arange(1, 8)
@@ -28,6 +29,7 @@
     [1 + 2*8]
 ])
 
+
 promos = np.array([2*8, 3*8, 4*8])
 pawn_promotion = np.array([
     -1 + promos,
@@ -65,12 +67,17 @@ def make_map():
         )
     )
     z = np.zeros((64*64+8*24, 1858), dtype=np.int32)
+    apm_out = np.zeros((1858,), dtype=np.int32)
+    apm_in = np.zeros((64*64+8*24), dtype=np.int32)
     # first loop for standard moves (for i in 0:1858, stride by 1)
     i = 0
     for pickup_index, putdown_indices in enumerate(traversable):
         for putdown_index in putdown_indices:
             if putdown_index < 64:
-                z[putdown_index + (64*pickup_index), i] = 1
+                du_idx = putdown_index + (64*pickup_index)
+                z[du_idx, i] = 1
+                apm_out[i] = du_idx
+                apm_in[du_idx] = i
                 i += 1
     # second loop for promotions (for i in 1792:1858, stride by ls[j])
     j = 0
@@ -87,8 +94,29 @@ def make_map():
         pickup_file = pickup_index % 8
         promotion_file = putdown_index % 8
         promotion_rank = (putdown_index // 8) - 8
-        z[4096 + pickup_file*24 + (promotion_file*3+promotion_rank), i] = 1
+        du_idx = 4096 + pickup_file*24 + (promotion_file*3+promotion_rank)
+        z[du_idx, i] = 1
+        apm_out[i] = du_idx
+        apm_in[du_idx] = i
         i += ls[j]
         j += 1
 
-    return z
+    return z, apm_out, apm_in
+
+
+apm_map, apm_out, apm_in = make_map()
+
+
+def set_zero_sum(x):
+    x = x + (1 - tf.reduce_sum(x, axis=1, keepdims=True)) * (1.0 / 64)
+    return x
+
+
+def get_up_down(moves):
+    out = tf.matmul(moves, apm_map, transpose_b=True)
+    out = out[..., :64*64]
+    out = tf.reshape(out, [-1, 64, 64])
+    pu = set_zero_sum(tf.reduce_sum(out, axis=-1))
+    pd = set_zero_sum(tf.reduce_sum(out, axis=-2))
+    print(pu.shape, pd.shape)
+    return pu, pd
```

10 changes: 8 additions & 2 deletions tf/chunkparsefunc.py

```diff
@@ -18,7 +18,7 @@
 import tensorflow as tf
 
 
-def parse_function(planes, probs, winner, q, plies_left):
+def parse_function(planes, probs, winner, q, plies_left, st_q, opp_idx, next_idx):
     """
     Convert unpacked record batches to tensors for tensorflow training
     """
@@ -27,11 +27,17 @@ def parse_function(planes, probs, winner, q, plies_left):
     winner = tf.io.decode_raw(winner, tf.float32)
     q = tf.io.decode_raw(q, tf.float32)
     plies_left = tf.io.decode_raw(plies_left, tf.float32)
+    st_q = tf.io.decode_raw(st_q, tf.float32)
+    opp_idx = tf.io.decode_raw(opp_idx, tf.int32)
+    next_idx = tf.io.decode_raw(next_idx, tf.int32)
 
     planes = tf.reshape(planes, (-1, 112, 8, 8))
     probs = tf.reshape(probs, (-1, 1858))
     winner = tf.reshape(winner, (-1, 3))
     q = tf.reshape(q, (-1, 3))
     plies_left = tf.reshape(plies_left, (-1, 1))
+    st_q = tf.reshape(st_q, (-1, 3))
+    opp_idx = tf.reshape(opp_idx, (-1,))
+    next_idx = tf.reshape(next_idx, (-1,))
 
-    return (planes, probs, winner, q, plies_left)
+    return (planes, probs, winner, q, plies_left, st_q, opp_idx, next_idx)
```