mirror of
https://github.com/20kaushik02/TCP-RL.git
synced 2025-12-06 06:34:06 +00:00
Agent configs, plot corrections
This commit is contained in:
parent
bb30abcc61
commit
6f762fe049
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import math
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
@ -29,10 +30,6 @@ parser.add_argument('--steps',
|
||||
type=int,
|
||||
default=100,
|
||||
help='Number of steps, Default 100')
|
||||
parser.add_argument('--debug',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Show debug output 0/1, Default 0')
|
||||
args = parser.parse_args()
|
||||
|
||||
startSim = bool(args.start)
|
||||
@ -41,7 +38,6 @@ maxSteps = int(args.steps)
|
||||
|
||||
port = 5555
|
||||
simTime = maxSteps / 10.0 # seconds
|
||||
stepTime = simTime / 200.0 # seconds
|
||||
seed = 12
|
||||
simArgs = {"--duration": simTime,}
|
||||
|
||||
@ -49,7 +45,7 @@ dashes = "-"*18
|
||||
input("[{}Press Enter to start{}]".format(dashes, dashes))
|
||||
|
||||
# create environment
|
||||
env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs)
|
||||
env = ns3env.Ns3Env(port=port, startSim=startSim, simSeed=seed, simArgs=simArgs)
|
||||
|
||||
ob_space = env.observation_space
|
||||
ac_space = env.action_space
|
||||
@ -90,9 +86,6 @@ def modeler(input_size, output_size):
|
||||
# input layer
|
||||
model.add(tf.keras.layers.Dense((input_size + output_size) // 2, input_shape=(input_size,), activation='relu'))
|
||||
|
||||
# hidden layer of mean size of input and output
|
||||
# model.add(tf.keras.layers.Dense((input_size + output_size) // 2, activation='relu'))
|
||||
|
||||
# output layer
|
||||
# maps previous layer of input_size units to output_size units
|
||||
# this is a classifier network
|
||||
@ -120,7 +113,7 @@ model.summary()
|
||||
# initialize decaying epsilon-greedy algorithm
|
||||
# fine-tune to ensure balance of exploration and exploitation
|
||||
epsilon = 1.0
|
||||
epsilon_decay_param = iterationNum * 5
|
||||
epsilon_decay_param = iterationNum * 2
|
||||
min_epsilon = 0.1
|
||||
epsilon_decay = (((epsilon_decay_param*maxSteps) - 1.0) / (epsilon_decay_param*maxSteps))
|
||||
|
||||
@ -132,6 +125,10 @@ rew_history = []
|
||||
cWnd_history = []
|
||||
pred_cWnd_history = []
|
||||
rtt_history = []
|
||||
tp_history = []
|
||||
|
||||
recency = maxSteps // 15
|
||||
|
||||
|
||||
done = False
|
||||
|
||||
@ -170,8 +167,34 @@ for iteration in range(iterationNum):
|
||||
print("\t[*] Exploiting gained knowledge. Selected action: {}".format(action_index), file=w_file)
|
||||
|
||||
# Calculate action
|
||||
# Note: prevent new_cWnd from falling too low to avoid negative values
|
||||
new_cWnd = cWnd + action_mapping[action_index]
|
||||
calc_cWnd = cWnd + action_mapping[action_index]
|
||||
|
||||
# Config 1: no cap
|
||||
# new_cWnd = calc_cWnd
|
||||
|
||||
# Config 2: cap cWnd by half upon congestion
|
||||
# ssThresh is set to half of cWnd when congestion occurs
|
||||
# prevent new_cWnd from falling too low
|
||||
# ssThresh = state[0][0]
|
||||
# new_cWnd = max(init_cWnd, (min(ssThresh, calc_cWnd)))
|
||||
|
||||
# Config 3: if throughput cap detected, fix cWnd
|
||||
# detect cap by checking if recent variance less than 1% of current
|
||||
thresh = state[0][0] # ssThresh
|
||||
if step+1 > recency:
|
||||
tp_dev = math.sqrt(np.var(tp_history[(-recency):]))
|
||||
tp_1per = 0.01 * throughput
|
||||
if tp_dev < tp_1per:
|
||||
thresh = cWnd
|
||||
new_cWnd = max(init_cWnd, (min(thresh, calc_cWnd)))
|
||||
|
||||
# Config 4: detect throughput cap by checking against experimentally determined value
|
||||
# thresh = state[0][0] # ssThresh
|
||||
# if step+1 > recency:
|
||||
# if throughput > 216000: # must be tuned based on bandwidth
|
||||
# thresh = cWnd
|
||||
# new_cWnd = max(init_cWnd, (min(thresh, calc_cWnd)))
|
||||
|
||||
new_ssThresh = int(cWnd/2)
|
||||
actions = [new_ssThresh, new_cWnd]
|
||||
|
||||
@ -183,6 +206,7 @@ for iteration in range(iterationNum):
|
||||
next_state = next_state[4:]
|
||||
cWnd = next_state[1]
|
||||
rtt = next_state[7]
|
||||
throughput = next_state[11]
|
||||
|
||||
print("\t[#] Next state: ", next_state, file=w_file)
|
||||
print("\t[!] Reward: ", reward, file=w_file)
|
||||
@ -215,7 +239,7 @@ for iteration in range(iterationNum):
|
||||
rew_history.append(rewardsum)
|
||||
rtt_history.append(rtt)
|
||||
cWnd_history.append(cWnd)
|
||||
pred_cWnd_history.append(new_cWnd)
|
||||
tp_history.append(throughput)
|
||||
|
||||
print("\n[O] Iteration over.", file=w_file)
|
||||
print("[-] Final epsilon value: ", epsilon, file=w_file)
|
||||
@ -230,19 +254,19 @@ for iteration in range(iterationNum):
|
||||
# break
|
||||
|
||||
mpl.rcdefaults()
|
||||
mpl.rcParams.update({'font.size': 12})
|
||||
mpl.rcParams.update({'font.size': 16})
|
||||
fig, ax = plt.subplots(2, 2, figsize=(4,2))
|
||||
plt.tight_layout(pad=0.3)
|
||||
|
||||
ax[0, 0].plot(range(len(cWnd_history)), cWnd_history, marker="", linestyle="-")
|
||||
ax[0, 0].set_title('Congestion windows')
|
||||
ax[0, 0].set_xlabel('Steps')
|
||||
ax[0, 0].set_ylabel('Actual CWND')
|
||||
ax[0, 0].set_ylabel('CWND (segments)')
|
||||
|
||||
ax[0, 1].plot(range(len(pred_cWnd_history)), pred_cWnd_history, marker="", linestyle="-")
|
||||
ax[0, 1].set_title('Predicted values')
|
||||
ax[0, 1].plot(range(len(tp_history)), tp_history, marker="", linestyle="-")
|
||||
ax[0, 1].set_title('Throughput over time')
|
||||
ax[0, 1].set_xlabel('Steps')
|
||||
ax[0, 1].set_ylabel('Predicted CWND')
|
||||
ax[0, 1].set_ylabel('Throughput (bits)')
|
||||
|
||||
ax[1, 0].plot(range(len(rtt_history)), rtt_history, marker="", linestyle="-")
|
||||
ax[1, 0].set_title('RTT over time')
|
||||
@ -252,7 +276,8 @@ ax[1, 0].set_ylabel('RTT (microseconds)')
|
||||
ax[1, 1].plot(range(len(rew_history)), rew_history, marker="", linestyle="-")
|
||||
ax[1, 1].set_title('Reward sum plot')
|
||||
ax[1, 1].set_xlabel('Steps')
|
||||
ax[1, 1].set_ylabel('Reward sum')
|
||||
ax[1, 1].set_ylabel('Accumulated reward')
|
||||
|
||||
plt.savefig('plots.png')
|
||||
plt.show()
|
||||
|
||||
|
||||
9
sim.cc
9
sim.cc
@ -76,6 +76,9 @@ int main (int argc, char *argv[])
|
||||
std::string queue_disc_type = "ns3::PfifoFastQueueDisc";
|
||||
std::string recovery = "ns3::TcpClassicRecovery";
|
||||
|
||||
double rew = 1.0;
|
||||
double pen = -1.0;
|
||||
|
||||
CommandLine cmd;
|
||||
// required parameters for OpenGym interface
|
||||
cmd.AddValue ("openGymPort", "Port number for OpenGym env. Default: 5555", openGymPort);
|
||||
@ -101,6 +104,8 @@ int main (int argc, char *argv[])
|
||||
cmd.AddValue ("queue_disc_type", "Queue disc type for gateway (e.g. ns3::CoDelQueueDisc)", queue_disc_type);
|
||||
cmd.AddValue ("sack", "Enable or disable SACK option", sack);
|
||||
cmd.AddValue ("recovery", "Recovery algorithm type to use (e.g., ns3::TcpPrrRecovery", recovery);
|
||||
cmd.AddValue ("reward", "Agent reward value", rew);
|
||||
cmd.AddValue ("penalty", "Agent penalty value", pen);
|
||||
cmd.Parse (argc, argv);
|
||||
|
||||
transport_prot = std::string ("ns3::") + transport_prot;
|
||||
@ -135,8 +140,8 @@ int main (int argc, char *argv[])
|
||||
openGymInterface = OpenGymInterface::Get(openGymPort);
|
||||
Config::SetDefault ("ns3::TcpRlTimeBased::StepTime", TimeValue (Seconds(tcpEnvTimeStep))); // Time step of env
|
||||
Config::SetDefault ("ns3::TcpRlTimeBased::Duration", TimeValue (Seconds(duration))); // Duration of env sim
|
||||
Config::SetDefault ("ns3::TcpRlTimeBased::Reward", DoubleValue (1.0)); // Reward
|
||||
Config::SetDefault ("ns3::TcpRlTimeBased::Penalty", DoubleValue (-1.0)); // Penalty
|
||||
Config::SetDefault ("ns3::TcpRlTimeBased::Reward", DoubleValue (rew)); // Reward
|
||||
Config::SetDefault ("ns3::TcpRlTimeBased::Penalty", DoubleValue (pen)); // Penalty
|
||||
}
|
||||
|
||||
// Calculate the ADU size
|
||||
|
||||
@ -531,6 +531,27 @@ TcpTimeStepGymEnv::GetObservation()
|
||||
}
|
||||
box->AddValue(avgRtt.GetMicroSeconds ());
|
||||
|
||||
//m_minRtt
|
||||
box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ());
|
||||
|
||||
//avgInterTx
|
||||
Time avgInterTx = Seconds(0.0);
|
||||
if (m_interTxTimeNum) {
|
||||
avgInterTx = m_interTxTimeSum / m_interTxTimeNum;
|
||||
}
|
||||
box->AddValue(avgInterTx.GetMicroSeconds ());
|
||||
|
||||
//avgInterRx
|
||||
Time avgInterRx = Seconds(0.0);
|
||||
if (m_interRxTimeNum) {
|
||||
avgInterRx = m_interRxTimeSum / m_interRxTimeNum;
|
||||
}
|
||||
box->AddValue(avgInterRx.GetMicroSeconds ());
|
||||
|
||||
//throughput bytes/s
|
||||
float throughput = (segmentsAckedSum * m_tcb->m_segmentSize) / m_timeStep.GetSeconds();
|
||||
box->AddValue(throughput);
|
||||
|
||||
/*---------------------------------------------------------------------------------------------------*/
|
||||
/*---------------------------------------------------------------------------------------------------*/
|
||||
/*---------------------------------------------------------------------------------------------------*/
|
||||
@ -563,27 +584,6 @@ TcpTimeStepGymEnv::GetObservation()
|
||||
/*---------------------------------------------------------------------------------------------------*/
|
||||
/*---------------------------------------------------------------------------------------------------*/
|
||||
|
||||
//m_minRtt
|
||||
box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ());
|
||||
|
||||
//avgInterTx
|
||||
Time avgInterTx = Seconds(0.0);
|
||||
if (m_interTxTimeNum) {
|
||||
avgInterTx = m_interTxTimeSum / m_interTxTimeNum;
|
||||
}
|
||||
box->AddValue(avgInterTx.GetMicroSeconds ());
|
||||
|
||||
//avgInterRx
|
||||
Time avgInterRx = Seconds(0.0);
|
||||
if (m_interRxTimeNum) {
|
||||
avgInterRx = m_interRxTimeSum / m_interRxTimeNum;
|
||||
}
|
||||
box->AddValue(avgInterRx.GetMicroSeconds ());
|
||||
|
||||
//throughput bytes/s
|
||||
float throughput = (segmentsAckedSum * m_tcb->m_segmentSize) / m_timeStep.GetSeconds();
|
||||
box->AddValue(throughput);
|
||||
|
||||
// Print data
|
||||
NS_LOG_INFO ("MyGetObservation: " << box);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user