Skip to content

Commit

Permalink
Updates on "small" lineValueNetwork
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed Aug 9, 2023
1 parent 308bf10 commit 021eae9
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 101 deletions.
190 changes: 93 additions & 97 deletions dominoes.ipynb

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions dominoesNetworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch import nn
import dominoesFunctions as df


class lineRepresentationNetwork(nn.Module):
def __init__(self, numPlayers, numDominoes, highestDominoe, finalScoreOutputDimension, numOutputCNN=1000, weightPrms=(0.,0.1),biasPrms=0.,actFunc=F.relu,pDropout=0):
super().__init__()
Expand Down Expand Up @@ -67,7 +68,7 @@ def __init__(self, numPlayers, numDominoes, highestDominoe, finalScoreOutputDime
# then, this can be passed as an extra input into a FF network
# the point is to use the same weights on the representations of every single dominoe, then process these transformed representations into the rest of the network
numLineFeatures = 6
numOutputChannels = 3
numOutputChannels = 10
numOutputValues = numOutputChannels * numDominoes
self.cnn_c1 = nn.Conv1d(numLineFeatures, numOutputChannels, 1)
self.cnn_f1 = nn.Linear(numOutputValues, self.numOutputCNN)
Expand All @@ -76,9 +77,9 @@ def __init__(self, numPlayers, numDominoes, highestDominoe, finalScoreOutputDime
self.cnnLayer = nn.Sequential(self.cnn_c1, nn.ReLU(), nn.Flatten(start_dim=0), self.cnn_f1, nn.ReLU(), self.cnn_ln)

# create ff network that integrates the standard network input with the convolutional output
self.fc1 = nn.Linear(self.inputDimension, 50)
self.fc2 = nn.Linear(50, 20)
self.fc3 = nn.Linear(20, 20)
self.fc1 = nn.Linear(self.inputDimension, 100)
self.fc2 = nn.Linear(100, 50)
self.fc3 = nn.Linear(50, 20)
self.fc4 = nn.Linear(20, self.outputDimension)
torch.nn.init.normal_(self.fc1.weight, mean=weightPrms[0], std=weightPrms[1])
torch.nn.init.normal_(self.fc2.weight, mean=weightPrms[0], std=weightPrms[1])
Expand Down

0 comments on commit 021eae9

Please sign in to comment.