Agent configs, plot corrections

This commit is contained in:
Kaushik Narayan R 2022-05-15 18:59:46 +05:30
parent bb30abcc61
commit 6f762fe049
3 changed files with 72 additions and 42 deletions

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import math
import sys import sys
import argparse import argparse
@ -29,10 +30,6 @@ parser.add_argument('--steps',
type=int, type=int,
default=100, default=100,
help='Number of steps, Default 100') help='Number of steps, Default 100')
parser.add_argument('--debug',
type=int,
default=0,
help='Show debug output 0/1, Default 0')
args = parser.parse_args() args = parser.parse_args()
startSim = bool(args.start) startSim = bool(args.start)
@ -41,7 +38,6 @@ maxSteps = int(args.steps)
port = 5555 port = 5555
simTime = maxSteps / 10.0 # seconds simTime = maxSteps / 10.0 # seconds
stepTime = simTime / 200.0 # seconds
seed = 12 seed = 12
simArgs = {"--duration": simTime,} simArgs = {"--duration": simTime,}
@ -49,7 +45,7 @@ dashes = "-"*18
input("[{}Press Enter to start{}]".format(dashes, dashes)) input("[{}Press Enter to start{}]".format(dashes, dashes))
# create environment # create environment
env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs) env = ns3env.Ns3Env(port=port, startSim=startSim, simSeed=seed, simArgs=simArgs)
ob_space = env.observation_space ob_space = env.observation_space
ac_space = env.action_space ac_space = env.action_space
@ -90,9 +86,6 @@ def modeler(input_size, output_size):
# input layer # input layer
model.add(tf.keras.layers.Dense((input_size + output_size) // 2, input_shape=(input_size,), activation='relu')) model.add(tf.keras.layers.Dense((input_size + output_size) // 2, input_shape=(input_size,), activation='relu'))
# hidden layer of mean size of input and output
# model.add(tf.keras.layers.Dense((input_size + output_size) // 2, activation='relu'))
# output layer # output layer
# maps previous layer of input_size units to output_size units # maps previous layer of input_size units to output_size units
# this is a classifier network # this is a classifier network
@ -120,7 +113,7 @@ model.summary()
# initialize decaying epsilon-greedy algorithm # initialize decaying epsilon-greedy algorithm
# fine-tune to ensure balance of exploration and exploitation # fine-tune to ensure balance of exploration and exploitation
epsilon = 1.0 epsilon = 1.0
epsilon_decay_param = iterationNum * 5 epsilon_decay_param = iterationNum * 2
min_epsilon = 0.1 min_epsilon = 0.1
epsilon_decay = (((epsilon_decay_param*maxSteps) - 1.0) / (epsilon_decay_param*maxSteps)) epsilon_decay = (((epsilon_decay_param*maxSteps) - 1.0) / (epsilon_decay_param*maxSteps))
@ -132,6 +125,10 @@ rew_history = []
cWnd_history = [] cWnd_history = []
pred_cWnd_history = [] pred_cWnd_history = []
rtt_history = [] rtt_history = []
tp_history = []
recency = maxSteps // 15
done = False done = False
@ -170,8 +167,34 @@ for iteration in range(iterationNum):
print("\t[*] Exploiting gained knowledge. Selected action: {}".format(action_index), file=w_file) print("\t[*] Exploiting gained knowledge. Selected action: {}".format(action_index), file=w_file)
# Calculate action # Calculate action
# Note: prevent new_cWnd from falling too low to avoid negative values calc_cWnd = cWnd + action_mapping[action_index]
new_cWnd = cWnd + action_mapping[action_index]
# Config 1: no cap
# new_cWnd = calc_cWnd
# Config 2: cap cWnd by half upon congestion
# ssThresh is set to half of cWnd when congestion occurs
# prevent new_cWnd from falling too low
# ssThresh = state[0][0]
# new_cWnd = max(init_cWnd, (min(ssThresh, calc_cWnd)))
# Config 3: if throughput cap detected, fix cWnd
# detect cap by checking if recent variance less than 1% of current
thresh = state[0][0] # ssThresh
if step+1 > recency:
tp_dev = math.sqrt(np.var(tp_history[(-recency):]))
tp_1per = 0.01 * throughput
if tp_dev < tp_1per:
thresh = cWnd
new_cWnd = max(init_cWnd, (min(thresh, calc_cWnd)))
# Config 4: detect throughput cap by checking against experimentally determined value
# thresh = state[0][0] # ssThresh
# if step+1 > recency:
# if throughput > 216000: # must be tuned based on bandwidth
# thresh = cWnd
# new_cWnd = max(init_cWnd, (min(thresh, calc_cWnd)))
new_ssThresh = int(cWnd/2) new_ssThresh = int(cWnd/2)
actions = [new_ssThresh, new_cWnd] actions = [new_ssThresh, new_cWnd]
@ -183,6 +206,7 @@ for iteration in range(iterationNum):
next_state = next_state[4:] next_state = next_state[4:]
cWnd = next_state[1] cWnd = next_state[1]
rtt = next_state[7] rtt = next_state[7]
throughput = next_state[11]
print("\t[#] Next state: ", next_state, file=w_file) print("\t[#] Next state: ", next_state, file=w_file)
print("\t[!] Reward: ", reward, file=w_file) print("\t[!] Reward: ", reward, file=w_file)
@ -215,7 +239,7 @@ for iteration in range(iterationNum):
rew_history.append(rewardsum) rew_history.append(rewardsum)
rtt_history.append(rtt) rtt_history.append(rtt)
cWnd_history.append(cWnd) cWnd_history.append(cWnd)
pred_cWnd_history.append(new_cWnd) tp_history.append(throughput)
print("\n[O] Iteration over.", file=w_file) print("\n[O] Iteration over.", file=w_file)
print("[-] Final epsilon value: ", epsilon, file=w_file) print("[-] Final epsilon value: ", epsilon, file=w_file)
@ -230,19 +254,19 @@ for iteration in range(iterationNum):
# break # break
mpl.rcdefaults() mpl.rcdefaults()
mpl.rcParams.update({'font.size': 12}) mpl.rcParams.update({'font.size': 16})
fig, ax = plt.subplots(2, 2, figsize=(4,2)) fig, ax = plt.subplots(2, 2, figsize=(4,2))
plt.tight_layout(pad=0.3) plt.tight_layout(pad=0.3)
ax[0, 0].plot(range(len(cWnd_history)), cWnd_history, marker="", linestyle="-") ax[0, 0].plot(range(len(cWnd_history)), cWnd_history, marker="", linestyle="-")
ax[0, 0].set_title('Congestion windows') ax[0, 0].set_title('Congestion windows')
ax[0, 0].set_xlabel('Steps') ax[0, 0].set_xlabel('Steps')
ax[0, 0].set_ylabel('Actual CWND') ax[0, 0].set_ylabel('CWND (segments)')
ax[0, 1].plot(range(len(pred_cWnd_history)), pred_cWnd_history, marker="", linestyle="-") ax[0, 1].plot(range(len(tp_history)), tp_history, marker="", linestyle="-")
ax[0, 1].set_title('Predicted values') ax[0, 1].set_title('Throughput over time')
ax[0, 1].set_xlabel('Steps') ax[0, 1].set_xlabel('Steps')
ax[0, 1].set_ylabel('Predicted CWND') ax[0, 1].set_ylabel('Throughput (bits)')
ax[1, 0].plot(range(len(rtt_history)), rtt_history, marker="", linestyle="-") ax[1, 0].plot(range(len(rtt_history)), rtt_history, marker="", linestyle="-")
ax[1, 0].set_title('RTT over time') ax[1, 0].set_title('RTT over time')
@ -252,7 +276,8 @@ ax[1, 0].set_ylabel('RTT (microseconds)')
ax[1, 1].plot(range(len(rew_history)), rew_history, marker="", linestyle="-") ax[1, 1].plot(range(len(rew_history)), rew_history, marker="", linestyle="-")
ax[1, 1].set_title('Reward sum plot') ax[1, 1].set_title('Reward sum plot')
ax[1, 1].set_xlabel('Steps') ax[1, 1].set_xlabel('Steps')
ax[1, 1].set_ylabel('Reward sum') ax[1, 1].set_ylabel('Accumulated reward')
plt.savefig('plots.png')
plt.show() plt.show()

9
sim.cc
View File

@ -76,6 +76,9 @@ int main (int argc, char *argv[])
std::string queue_disc_type = "ns3::PfifoFastQueueDisc"; std::string queue_disc_type = "ns3::PfifoFastQueueDisc";
std::string recovery = "ns3::TcpClassicRecovery"; std::string recovery = "ns3::TcpClassicRecovery";
double rew = 1.0;
double pen = -1.0;
CommandLine cmd; CommandLine cmd;
// required parameters for OpenGym interface // required parameters for OpenGym interface
cmd.AddValue ("openGymPort", "Port number for OpenGym env. Default: 5555", openGymPort); cmd.AddValue ("openGymPort", "Port number for OpenGym env. Default: 5555", openGymPort);
@ -101,6 +104,8 @@ int main (int argc, char *argv[])
cmd.AddValue ("queue_disc_type", "Queue disc type for gateway (e.g. ns3::CoDelQueueDisc)", queue_disc_type); cmd.AddValue ("queue_disc_type", "Queue disc type for gateway (e.g. ns3::CoDelQueueDisc)", queue_disc_type);
cmd.AddValue ("sack", "Enable or disable SACK option", sack); cmd.AddValue ("sack", "Enable or disable SACK option", sack);
cmd.AddValue ("recovery", "Recovery algorithm type to use (e.g., ns3::TcpPrrRecovery", recovery); cmd.AddValue ("recovery", "Recovery algorithm type to use (e.g., ns3::TcpPrrRecovery", recovery);
cmd.AddValue ("reward", "Agent reward value", rew);
cmd.AddValue ("penalty", "Agent penalty value", pen);
cmd.Parse (argc, argv); cmd.Parse (argc, argv);
transport_prot = std::string ("ns3::") + transport_prot; transport_prot = std::string ("ns3::") + transport_prot;
@ -135,8 +140,8 @@ int main (int argc, char *argv[])
openGymInterface = OpenGymInterface::Get(openGymPort); openGymInterface = OpenGymInterface::Get(openGymPort);
Config::SetDefault ("ns3::TcpRlTimeBased::StepTime", TimeValue (Seconds(tcpEnvTimeStep))); // Time step of env Config::SetDefault ("ns3::TcpRlTimeBased::StepTime", TimeValue (Seconds(tcpEnvTimeStep))); // Time step of env
Config::SetDefault ("ns3::TcpRlTimeBased::Duration", TimeValue (Seconds(duration))); // Duration of env sim Config::SetDefault ("ns3::TcpRlTimeBased::Duration", TimeValue (Seconds(duration))); // Duration of env sim
Config::SetDefault ("ns3::TcpRlTimeBased::Reward", DoubleValue (1.0)); // Reward Config::SetDefault ("ns3::TcpRlTimeBased::Reward", DoubleValue (rew)); // Reward
Config::SetDefault ("ns3::TcpRlTimeBased::Penalty", DoubleValue (-1.0)); // Penalty Config::SetDefault ("ns3::TcpRlTimeBased::Penalty", DoubleValue (pen)); // Penalty
} }
// Calculate the ADU size // Calculate the ADU size

View File

@ -531,6 +531,27 @@ TcpTimeStepGymEnv::GetObservation()
} }
box->AddValue(avgRtt.GetMicroSeconds ()); box->AddValue(avgRtt.GetMicroSeconds ());
//m_minRtt
box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ());
//avgInterTx
Time avgInterTx = Seconds(0.0);
if (m_interTxTimeNum) {
avgInterTx = m_interTxTimeSum / m_interTxTimeNum;
}
box->AddValue(avgInterTx.GetMicroSeconds ());
//avgInterRx
Time avgInterRx = Seconds(0.0);
if (m_interRxTimeNum) {
avgInterRx = m_interRxTimeSum / m_interRxTimeNum;
}
box->AddValue(avgInterRx.GetMicroSeconds ());
//throughput bytes/s
float throughput = (segmentsAckedSum * m_tcb->m_segmentSize) / m_timeStep.GetSeconds();
box->AddValue(throughput);
/*---------------------------------------------------------------------------------------------------*/ /*---------------------------------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------------------------------*/ /*---------------------------------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------------------------------*/ /*---------------------------------------------------------------------------------------------------*/
@ -563,27 +584,6 @@ TcpTimeStepGymEnv::GetObservation()
/*---------------------------------------------------------------------------------------------------*/ /*---------------------------------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------------------------------*/ /*---------------------------------------------------------------------------------------------------*/
//m_minRtt
box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ());
//avgInterTx
Time avgInterTx = Seconds(0.0);
if (m_interTxTimeNum) {
avgInterTx = m_interTxTimeSum / m_interTxTimeNum;
}
box->AddValue(avgInterTx.GetMicroSeconds ());
//avgInterRx
Time avgInterRx = Seconds(0.0);
if (m_interRxTimeNum) {
avgInterRx = m_interRxTimeSum / m_interRxTimeNum;
}
box->AddValue(avgInterRx.GetMicroSeconds ());
//throughput bytes/s
float throughput = (segmentsAckedSum * m_tcb->m_segmentSize) / m_timeStep.GetSeconds();
box->AddValue(throughput);
// Print data // Print data
NS_LOG_INFO ("MyGetObservation: " << box); NS_LOG_INFO ("MyGetObservation: " << box);