Skip to content

Commit

Permalink
Performances improvement #12 and #16
Browse files Browse the repository at this point in the history
  • Loading branch information
MissonO committed Mar 26, 2024
1 parent 6edd0a0 commit 8b467a5
Show file tree
Hide file tree
Showing 71 changed files with 6,627 additions and 4,355 deletions.
61 changes: 51 additions & 10 deletions src/MIC/Assets/.MeshCloudScripting/Presentation1/App.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ public class App : IHostedService, IAsyncDisposable
private readonly AppSettings _appSettings;
private readonly PersonFlow _personFlow;
private readonly List<QLearning> _qLearnings;
private CancellationTokenSource cts1 = new CancellationTokenSource();
private CancellationTokenSource cts2 = new CancellationTokenSource();
private CancellationTokenSource cts3 = new CancellationTokenSource();
private CancellationTokenSource cts4 = new CancellationTokenSource();
private CancellationTokenSource cts5 = new CancellationTokenSource();
private Dictionary<string, Vector3> destinationsList = new Dictionary<string, Vector3>
{
{"Cafe", new Vector3(-2, 0.1f, 3)}, //Machine à café
{"Innover", new Vector3(-27, 0.1f, 5)}, // Innover
{"Loft", new Vector3()}, // Loft
};
public App(ICloudApplication app, ILogger<App> logger)
{
_app = app;
Expand All @@ -39,7 +50,7 @@ public App(ICloudApplication app, ILogger<App> logger)
for (int i = 0; i < 5; i++)
{
int numStates = 1800;
_qLearnings.Add(new QLearning(numStates, 8, 0.7, 0.9, 0.6, 0));
_qLearnings.Add(new QLearning(numStates, 8, 0.7, 0.7, 1, 0));
}
}
private AppSettings? LoadSettings()
Expand Down Expand Up @@ -80,21 +91,51 @@ public async Task StartAsync(CancellationToken token)
await UploadImageToBlobStorage(2, _appSettings);
btnSphere.IsActive = false;
};
Vector3 destination = new Vector3(-25, 0, 1);
var wall = (TransformNode)_app.Scene.FindChildByPath("QLearning/Wall");

var btnSimulationCafe = (TransformNode)_app.Scene.FindChildByPath("Simulation/ButtonCafe");
var sensorCafe = btnSimulationCafe.FindFirstChild<InteractableNode>();
var btnSImulationInnover = (TransformNode)_app.Scene.FindChildByPath("Simulation/ButtonInnover");
var sensorInnover = btnSImulationInnover.FindFirstChild<InteractableNode>();
sensorCafe.Selected += async (sender, args) =>
{
await StartSimulation("Cafe");
};
sensorInnover.Selected += async (sender, args) =>
{
await StartSimulation("Innover");
};
}

public async Task StartSimulation(string simulationAction) //0 to go to the coffee
{
Vector3 destination = destinationsList[simulationAction];
var npc1 = (TransformNode)_app.Scene.FindChildByPath("HumanMale_Character");
var npc2 = (TransformNode)_app.Scene.FindChildByPath("HumanMale_Character1");
var npc3 = (TransformNode)_app.Scene.FindChildByPath("HumanMale_Character2");
var npc4 = (TransformNode)_app.Scene.FindChildByPath("HumanMale_Character3");
var npc5 = (TransformNode)_app.Scene.FindChildByPath("HumanMale_Character4");
_qLearnings[0].MoveAction(npc1, destination, 0, 10000);
//_qLearnings[1].MoveAction(npc2, destination, 1, 100);
//_qLearnings[2].MoveAction(npc3, destination, 10000);
//_qLearnings[3].MoveAction(npc4, destination, 10000);
//_qLearnings[4].MoveAction(npc5, destination, 10000);

//var move = 5;
//_personFlow.Boucle(npc, move);
cts1.Cancel();
cts2.Cancel();
cts3.Cancel();
cts4.Cancel();
cts5.Cancel();

cts1 = new CancellationTokenSource();
cts2 = new CancellationTokenSource();
cts3 = new CancellationTokenSource();
cts4 = new CancellationTokenSource();
cts5 = new CancellationTokenSource();

_qLearnings[0].MoveAction(npc1, destination, simulationAction, 0, 10000, cts1.Token);
await Task.Delay(50);
_qLearnings[1].MoveAction(npc2, destination, simulationAction, 1, 10000, cts2.Token);
await Task.Delay(50);
_qLearnings[2].MoveAction(npc3, destination, simulationAction, 2, 10000, cts3.Token);
await Task.Delay(50);
_qLearnings[3].MoveAction(npc4, destination, simulationAction, 3, 10000, cts4.Token);
await Task.Delay(50);
_qLearnings[4].MoveAction(npc5, destination, simulationAction, 4, 10000, cts5.Token);
}

public async Task<string> GetImage(TransformNode node, string imageUrl)
Expand Down
172 changes: 99 additions & 73 deletions src/MIC/Assets/.MeshCloudScripting/Presentation1/QLearning.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.IO;
using System.Numerics;
using Microsoft.Mesh.CloudScripting;
Expand All @@ -9,27 +10,38 @@ namespace Presentation1
{
public class QLearning
{
private double[,] qTable; // Table for the memory of the npc, works with a reward value
private Dictionary<int, double> maxQValues = new Dictionary<int, double>();
private Dictionary<int, int> maxQActions = new Dictionary<int, int>();
private int numStates, numActions; //numState : number of stats possible | numActions : number of actions posible
private double learningRate, discountFactor, explorationRate; // Parameters for the Q learning algorithm
private float lastDistance = 9999; // Stores the last distance from the choosen destination
private Random rnd = new Random();
const int GRID_SIZE = 60;
const int STATE_MODULUS = 100000;
private Dictionary<int, Vector3> actionDirections = new Dictionary<int, Vector3>
const float DISTANCE_THRESHOLD = 0.1f;
const int REWARD_GOAL = 100;
const int REWARD_FAR = -1;
const int REWARD_CLOSE = 1;

private readonly double[,] qTable; // Table for the memory of the npc, works with a reward value
private readonly int numStates, numActions; //numState : number of stats possible | numActions : number of actions posible
private readonly double learningRate, discountFactor, explorationRate; // Parameters for the Q learning algorithm
private readonly bool[,] gridObstacles;
private readonly object lockObject = new object();
private readonly Random rnd = new Random();

private List<Vector3> npcPositions = new List<Vector3>();

private float lastDistance = 9999; // Stores the last distance from the choosen destination

private readonly Dictionary<int, double> maxQValues = new Dictionary<int, double>();
private readonly Dictionary<int, int> maxQActions = new Dictionary<int, int>();
private readonly Dictionary<float, Vector3> actionDirections = new Dictionary<float, Vector3>
{
{0, new Vector3(0, 0, 1)}, // Move up
{1, new Vector3(0, 0, -1)}, // Move down
{2, new Vector3(-1, 0, 0)}, // Move left
{3, new Vector3(1, 0, 0)}, // Move right
{4, new Vector3(1, 0f, 1)}, // Move up-right
{5, new Vector3(-1, 0f, 1)}, // Move up-left
{6, new Vector3(1, 0f, -1)}, // Move down-right
{7, new Vector3(-1, 0f, -1)} // Move down-left
{0, new Vector3(0, 0, 1f)}, // Move up
{1, new Vector3(0, 0, -1f)}, // Move down
{2, new Vector3(-1f, 0, 0)}, // Move left
{3, new Vector3(1f, 0, 0)}, // Move right
{4, new Vector3(1f, 0f, 1f)}, // Move up-right
{5, new Vector3(-01f, 0f, 1f)}, // Move up-left
{6, new Vector3(1f, 0f, -1f)}, // Move down-right
{7, new Vector3(-1f, 0f, -1f)} // Move down-left
};
private bool[,] gridObstacles;

public QLearning(int numStates, int numActions, double learningRate, double discountFactor, double explorationRate, int npcNum)
{
this.numStates = numStates;
Expand All @@ -40,7 +52,6 @@ public QLearning(int numStates, int numActions, double learningRate, double disc

qTable = new double[numStates, numActions];

LoadQTable(npcNum);
gridObstacles = LoadGrid();
}

Expand All @@ -55,22 +66,23 @@ public int ChooseAction(int state)
}
else
{
maxQActions.TryGetValue(state, out int action);
// Exploit: select the action with max value
return maxQActions.ContainsKey(state) ? maxQActions[state] : 0;
return action;
}
}

// Update the Q value in the Q table with the reward it gets
public void UpdateQValue(int prevState, int action, float reward, int nextState)
{
double oldValue = qTable[prevState, action];
if (!maxQValues.ContainsKey(prevState) || qTable[prevState, action] > maxQValues[prevState])
{
double learnedValue = reward + discountFactor * (maxQValues.ContainsKey(nextState) ? maxQValues[nextState] : 0);
qTable[prevState, action] += learningRate * (learnedValue - oldValue);
maxQValues[prevState] = qTable[prevState, action];
maxQActions[prevState] = action;
}
double oldValue = qTable[prevState, action];
if (!maxQValues.ContainsKey(prevState) || qTable[prevState, action] > maxQValues[prevState])
{
double learnedValue = reward + discountFactor * (maxQValues.ContainsKey(nextState) ? maxQValues[nextState] : 0);
qTable[prevState, action] += learningRate * (learnedValue - oldValue);
maxQValues[prevState] = qTable[prevState, action];
maxQActions[prevState] = action;
}
}

// Get the position/state of the npc
Expand All @@ -86,17 +98,22 @@ public int GetState(TransformNode npc)
}

// Main function that make the npc move and calls all the subfunctions
public async void MoveAction(TransformNode npc, Vector3 destination, int npcNum, int numIterations)
public async void MoveAction(TransformNode npc, Vector3 destination, string destinationName, int npcNum, int numIterations, CancellationToken cancellationToken)
{
LoadQTable(npcNum, destinationName);
for (int i = 0; i < numIterations; i++)
{
if (cancellationToken.IsCancellationRequested)
{
return;
}
int prevState = GetState(npc);
int action = ChooseAction(prevState);

Vector3 direction = actionDirections[action];

await RotateNpc(npc, direction);
await MoveNpc(npc, direction, npcNum);
await RotateNpc(npc, direction, cancellationToken);
await MoveNpc(npc, direction, npcNum, cancellationToken);

// Calculate the reward
float reward = CalculateReward(npc, destination);
Expand All @@ -107,39 +124,51 @@ public async void MoveAction(TransformNode npc, Vector3 destination, int npcNum,
UpdateQValue(prevState, action, reward, nextState);
if (i % 100 == 0)
{
SaveQTable(npcNum);
SaveQTable(npcNum, destinationName);
}
if (Vector3.Distance(npc.Position, destination) <= DISTANCE_THRESHOLD) // 0.1f is a small threshold to account for floating point precision
{
break; // Stop the movement
}
}
}

public async Task MoveNpc(TransformNode npc, Vector3 direction, int npcNum)
public async Task MoveNpc(TransformNode npc, Vector3 direction, int npcNum, CancellationToken cancellationToken)
{
float duration = 2f;
float remainingTime = duration;
Vector3 desiredPosition = npc.Position + direction;

if (npcPositions.Any(pos => Vector3.Distance(pos, desiredPosition) < DISTANCE_THRESHOLD))
{
// If there is a collision, return without moving the NPC
return;
}

if (desiredPosition.X >= -GRID_SIZE / 2 && desiredPosition.X < GRID_SIZE / 2 && desiredPosition.Z >= -GRID_SIZE / 2 && desiredPosition.Z < GRID_SIZE / 2 && !gridObstacles[(int)desiredPosition.X + GRID_SIZE / 2, (int)desiredPosition.Z + GRID_SIZE / 2])
{
float t = 0f;
while (t < 1f)
{
float delatTime = 0.01f;
float stepSize = delatTime / duration;
if (cancellationToken.IsCancellationRequested)
{
return;
}
float deltaTime = 0.02f;
float stepSize = deltaTime / duration;
t += stepSize;
npc.Position = Vector3.Lerp(npc.Position, desiredPosition, t);

if (t > 1f) t = 1f;
await Task.Delay((int)(delatTime * 1000));
remainingTime -= delatTime;
await Task.Delay((int)(deltaTime * 100));
}
}

}

public async Task RotateNpc(TransformNode npc, Vector3 direction)
public async Task RotateNpc(TransformNode npc, Vector3 direction, CancellationToken cancellationToken)
{
Vector3 normalizedDirection = Vector3.Normalize(direction);

float rotationAngleRadians = MathF.Atan2(normalizedDirection.X, normalizedDirection.Z);
if (rotationAngleRadians == 0)
{
Expand All @@ -151,21 +180,23 @@ public async Task RotateNpc(TransformNode npc, Vector3 direction)
Quaternion rotation = new Quaternion(rotationAxis.X, rotationAxis.Y, rotationAxis.Z, MathF.Cos(rotationAngleRadians / 2));
if (Equals(npc.Rotation, rotation)) return;

float duration = 1f;
float remainingTime = duration;
float duration = 0.5f;
float t = 0f;
while (t < 1f)
{
float deltaTime = 0.01f; // Adjust as needed
if (cancellationToken.IsCancellationRequested)
{
return;
}
float deltaTime = 0.02f; // Adjust as needed
float stepSize = deltaTime / duration;

t += stepSize;
npc.Rotation = Quaternion.Slerp(npc.Rotation, rotation, t);
npc.Rotation = Quaternion.Slerp(npc.Rotation, rotation, t * t * (3 - 2 * t));

if (t > 1f) t = 1f;

await Task.Delay((int)(deltaTime * 1000)); // Convert deltaTime to milliseconds
remainingTime -= deltaTime;
await Task.Delay((int)(deltaTime * 100));
}
}

Expand All @@ -177,52 +208,47 @@ public int CalculateReward(TransformNode npc, Vector3 destination)

if (distance == 0)
{
return 100; // Big reward for reaching the goal
}
else if (distance >= lastDistance)
{
lastDistance = distance;
return -1;
return REWARD_GOAL; // Big reward for reaching the goal
}
else
{
lastDistance = distance;
return 1;
return -1 * (int)distance;
}
}

// At the end of the movement, save the Q table in a json to exploit it at the next launchs
public void SaveQTable(int npcNum)
public void SaveQTable(int npcNum, string destinationName)
{
var qTableList = new List<List<double>>();
for (int i = 0; i < numStates; i++)
{
var row = new List<double>();
for (int j = 0; j < numActions; j++)
var qTableList = new List<List<double>>();
for (int i = 0; i < numStates; i++)
{
row.Add(qTable[i, j]);
var row = new List<double>();
for (int j = 0; j < numActions; j++)
{
row.Add(qTable[i, j]);
}
qTableList.Add(row);
}
qTableList.Add(row);
}
string finalFilePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "qtable" + npcNum + ".json"); ;
File.WriteAllText(finalFilePath, JsonConvert.SerializeObject(qTableList));
string finalFilePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "qtable" + npcNum + destinationName + ".json"); ;
File.WriteAllText(finalFilePath, JsonConvert.SerializeObject(qTableList));
}


// Load the Q table
public void LoadQTable(int npcNum)
public void LoadQTable(int npcNum, string destinationName)
{
string finalFilePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "qtable" + npcNum + ".json");
if (File.Exists(finalFilePath))
{
var qTableList = JsonConvert.DeserializeObject<List<List<double>>>(File.ReadAllText(finalFilePath));
for (int i = 0; i < numStates; i++)
string finalFilePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "qtable" + npcNum + destinationName + ".json");
if (File.Exists(finalFilePath))
{
for (int j = 0; j < numActions; j++)
var qTableList = JsonConvert.DeserializeObject<List<List<double>>>(File.ReadAllText(finalFilePath));
for (int i = 0; i < numStates; i++)
{
qTable[i, j] = qTableList[i][j];
for (int j = 0; j < numActions; j++)
{
qTable[i, j] = qTableList[i][j];
}
}
}
}
}

public bool[,] LoadGrid()
Expand Down
Git LFS file not shown
Git LFS file not shown
Loading

0 comments on commit 8b467a5

Please sign in to comment.