mirror of
https://github.com/20kaushik02/TCP-RL.git
synced 2025-12-06 07:54:07 +00:00
Upload files
This commit is contained in:
parent
2d77152b5d
commit
44717b8f3e
258
TCP-RL-Agent.py
Executable file
258
TCP-RL-Agent.py
Executable file
@ -0,0 +1,258 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from ns3gym import ns3env
|
||||
from tcp_base import TcpTimeBased, TcpEventBased
|
||||
|
||||
try:
|
||||
w_file = open('run.log', 'w')
|
||||
except:
|
||||
w_file = sys.stdout
|
||||
parser = argparse.ArgumentParser(description='Start simulation script on/off')
|
||||
parser.add_argument('--start',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Start ns-3 simulation script 0/1, Default: 1')
|
||||
parser.add_argument('--iterations',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of iterations, Default: 1')
|
||||
parser.add_argument('--steps',
|
||||
type=int,
|
||||
default=100,
|
||||
help='Number of steps, Default 100')
|
||||
parser.add_argument('--debug',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Show debug output 0/1, Default 0')
|
||||
args = parser.parse_args()
|
||||
|
||||
startSim = bool(args.start)
|
||||
iterationNum = int(args.iterations)
|
||||
maxSteps = int(args.steps)
|
||||
|
||||
port = 5555
|
||||
simTime = maxSteps / 10.0 # seconds
|
||||
stepTime = simTime / 200.0 # seconds
|
||||
seed = 12
|
||||
simArgs = {"--duration": simTime,}
|
||||
|
||||
dashes = "-"*18
|
||||
input("[{}Press Enter to start{}]".format(dashes, dashes))
|
||||
|
||||
# create environment
|
||||
env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs)
|
||||
|
||||
ob_space = env.observation_space
|
||||
ac_space = env.action_space
|
||||
|
||||
# TODO: right now, the next action is selected inside the loop, rather than using get_action.
|
||||
# this is because we use the decaying epsilon-greedy algo which needs to use the live model
|
||||
# somehow change or put that logic in an `RLTCP` class that inherits from the Tcp class, like in tcp_base.py,
|
||||
# then move the class to tcp_base.py and use that agent here
|
||||
def get_agent(state):
|
||||
socketUuid = state[0]
|
||||
tcpEnvType = state[1]
|
||||
tcpAgent = get_agent.tcpAgents.get(socketUuid, None)
|
||||
if tcpAgent is None:
|
||||
# get a new agent based on the selected env type
|
||||
if tcpEnvType == 0:
|
||||
# event-based = 0
|
||||
tcpAgent = TcpEventBased()
|
||||
else:
|
||||
# time-based = 1
|
||||
tcpAgent = TcpTimeBased()
|
||||
tcpAgent.set_spaces(get_agent.ob_space, get_agent.ac_space)
|
||||
get_agent.tcpAgents[socketUuid] = tcpAgent
|
||||
|
||||
return tcpAgent
|
||||
|
||||
# initialize agent variables
|
||||
# (useless until the above todo is fixed)
|
||||
get_agent.tcpAgents = {}
|
||||
get_agent.ob_space = ob_space
|
||||
get_agent.ac_space = ac_space
|
||||
|
||||
def modeler(input_size, output_size):
|
||||
"""
|
||||
Designs a fully connected neural network.
|
||||
"""
|
||||
model = tf.keras.Sequential()
|
||||
|
||||
# input layer
|
||||
model.add(tf.keras.layers.Dense((input_size + output_size) // 2, input_shape=(input_size,), activation='relu'))
|
||||
|
||||
# hidden layer of mean size of input and output
|
||||
# model.add(tf.keras.layers.Dense((input_size + output_size) // 2, activation='relu'))
|
||||
|
||||
# output layer
|
||||
# maps previous layer of input_size units to output_size units
|
||||
# this is a classifier network
|
||||
model.add(tf.keras.layers.Dense(output_size, activation='softmax'))
|
||||
|
||||
return model
|
||||
|
||||
state_size = ob_space.shape[0] - 4 # ignoring 4 env attributes
|
||||
|
||||
action_size = 3
|
||||
action_mapping = {} # dict faster than list
|
||||
action_mapping[0] = 0
|
||||
action_mapping[1] = 600
|
||||
action_mapping[2] = -150
|
||||
|
||||
# build model
|
||||
model = modeler(state_size, action_size)
|
||||
model.compile(
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2),
|
||||
loss='categorical_crossentropy',
|
||||
metrics=['accuracy']
|
||||
)
|
||||
model.summary()
|
||||
|
||||
# initialize decaying epsilon-greedy algorithm
|
||||
# fine-tune to ensure balance of exploration and exploitation
|
||||
epsilon = 1.0
|
||||
epsilon_decay_param = iterationNum * 5
|
||||
min_epsilon = 0.1
|
||||
epsilon_decay = (((epsilon_decay_param*maxSteps) - 1.0) / (epsilon_decay_param*maxSteps))
|
||||
|
||||
# initialize Q-learning's discount factor
|
||||
discount_factor = 0.95
|
||||
|
||||
rewardsum = 0
|
||||
rew_history = []
|
||||
cWnd_history = []
|
||||
pred_cWnd_history = []
|
||||
rtt_history = []
|
||||
|
||||
done = False
|
||||
|
||||
pretty_slash = ['\\', '|', '/', '-']
|
||||
|
||||
for iteration in range(iterationNum):
|
||||
# set initial state
|
||||
state = env.reset()
|
||||
# ignore env attributes: socketID, env type, sim time, nodeID
|
||||
state = state[4:]
|
||||
|
||||
cWnd = state[1]
|
||||
init_cWnd = cWnd
|
||||
|
||||
state = np.reshape(state, [1, state_size])
|
||||
try:
|
||||
for step in range(maxSteps):
|
||||
pretty_index = step % 4
|
||||
print("\r{}\r[{}] Logging to file {} {}".format(
|
||||
' '*(25+len(w_file.name)),
|
||||
pretty_slash[pretty_index],
|
||||
w_file.name,
|
||||
'.'*(pretty_index+1)
|
||||
), end='')
|
||||
|
||||
print("[+] Step: {}".format(step+1), file=w_file)
|
||||
|
||||
# Epsilon-greedy selection
|
||||
if step == 0 or np.random.rand(1) < epsilon:
|
||||
# explore new situation
|
||||
action_index = np.random.randint(0, action_size)
|
||||
print("\t[*] Random exploration. Selected action: {}".format(action_index), file=w_file)
|
||||
else:
|
||||
# exploit gained knowledge
|
||||
action_index = np.argmax(model.predict(state)[0])
|
||||
print("\t[*] Exploiting gained knowledge. Selected action: {}".format(action_index), file=w_file)
|
||||
|
||||
# Calculate action
|
||||
# Note: prevent new_cWnd from falling too low to avoid negative values
|
||||
new_cWnd = cWnd + action_mapping[action_index]
|
||||
new_ssThresh = int(cWnd/2)
|
||||
actions = [new_ssThresh, new_cWnd]
|
||||
|
||||
# Take action step on environment and get feedback
|
||||
next_state, reward, done, _ = env.step(actions)
|
||||
|
||||
rewardsum += reward
|
||||
|
||||
next_state = next_state[4:]
|
||||
cWnd = next_state[1]
|
||||
rtt = next_state[7]
|
||||
|
||||
print("\t[#] Next state: ", next_state, file=w_file)
|
||||
print("\t[!] Reward: ", reward, file=w_file)
|
||||
|
||||
next_state = np.reshape(next_state, [1, state_size])
|
||||
|
||||
|
||||
# Train incrementally
|
||||
# DQN - function approximation using neural networks
|
||||
target = reward
|
||||
if not done:
|
||||
target = (reward + discount_factor * np.amax(model.predict(next_state)[0]))
|
||||
target_f = model.predict(state)
|
||||
target_f[0][action_index] = target
|
||||
model.fit(state, target_f, epochs=1, verbose=0)
|
||||
|
||||
# Update state
|
||||
state = next_state
|
||||
|
||||
if done:
|
||||
print("[X] Stopping: step: {}, reward sum: {}, epsilon: {:.2}"
|
||||
.format(step+1, rewardsum, epsilon),
|
||||
file=w_file)
|
||||
break
|
||||
|
||||
if epsilon > min_epsilon:
|
||||
epsilon *= epsilon_decay
|
||||
|
||||
# Record information
|
||||
rew_history.append(rewardsum)
|
||||
rtt_history.append(rtt)
|
||||
cWnd_history.append(cWnd)
|
||||
pred_cWnd_history.append(new_cWnd)
|
||||
|
||||
print("\n[O] Iteration over.", file=w_file)
|
||||
print("[-] Final epsilon value: ", epsilon, file=w_file)
|
||||
print("[-] Final reward sum: ", rewardsum, file=w_file)
|
||||
print()
|
||||
|
||||
finally:
|
||||
print()
|
||||
if iteration+1 == iterationNum:
|
||||
break
|
||||
# if str(input("[?] Continue to next iteration? [Y/n]: ") or "Y").lower() != "y":
|
||||
# break
|
||||
|
||||
mpl.rcdefaults()
|
||||
mpl.rcParams.update({'font.size': 12})
|
||||
fig, ax = plt.subplots(2, 2, figsize=(4,2))
|
||||
plt.tight_layout(pad=0.3)
|
||||
|
||||
ax[0, 0].plot(range(len(cWnd_history)), cWnd_history, marker="", linestyle="-")
|
||||
ax[0, 0].set_title('Congestion windows')
|
||||
ax[0, 0].set_xlabel('Steps')
|
||||
ax[0, 0].set_ylabel('Actual CWND')
|
||||
|
||||
ax[0, 1].plot(range(len(pred_cWnd_history)), pred_cWnd_history, marker="", linestyle="-")
|
||||
ax[0, 1].set_title('Predicted values')
|
||||
ax[0, 1].set_xlabel('Steps')
|
||||
ax[0, 1].set_ylabel('Predicted CWND')
|
||||
|
||||
ax[1, 0].plot(range(len(rtt_history)), rtt_history, marker="", linestyle="-")
|
||||
ax[1, 0].set_title('RTT over time')
|
||||
ax[1, 0].set_xlabel('Steps')
|
||||
ax[1, 0].set_ylabel('RTT (microseconds)')
|
||||
|
||||
ax[1, 1].plot(range(len(rew_history)), rew_history, marker="", linestyle="-")
|
||||
ax[1, 1].set_title('Reward sum plot')
|
||||
ax[1, 1].set_xlabel('Steps')
|
||||
ax[1, 1].set_ylabel('Reward sum')
|
||||
|
||||
plt.show()
|
||||
|
||||
333
sim.cc
Normal file
333
sim.cc
Normal file
@ -0,0 +1,333 @@
|
||||
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
|
||||
/*
|
||||
* Copyright (c) 2018 Piotr Gawlowicz
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License version 2 as
|
||||
* published by the Free Software Foundation;
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program; if not, write to the Free Software
|
||||
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
|
||||
*
|
||||
* Author: Piotr Gawlowicz <gawlowicz.p@gmail.com>
|
||||
* Based on script: ./examples/tcp/tcp-variants-comparison.cc
|
||||
*
|
||||
* Topology:
|
||||
*
|
||||
* Right Leafs (Clients) Left Leafs (Sinks)
|
||||
* | \ / |
|
||||
* | \ bottleneck / |
|
||||
* | R0--------------R1 |
|
||||
* | / \ |
|
||||
* | access / \ access |
|
||||
* N ----------- --------N
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
#include "ns3/core-module.h"
|
||||
#include "ns3/network-module.h"
|
||||
#include "ns3/internet-module.h"
|
||||
#include "ns3/point-to-point-module.h"
|
||||
#include "ns3/point-to-point-layout-module.h"
|
||||
#include "ns3/applications-module.h"
|
||||
#include "ns3/error-model.h"
|
||||
#include "ns3/tcp-header.h"
|
||||
#include "ns3/enum.h"
|
||||
#include "ns3/event-id.h"
|
||||
#include "ns3/flow-monitor-helper.h"
|
||||
#include "ns3/ipv4-global-routing-helper.h"
|
||||
#include "ns3/traffic-control-module.h"
|
||||
|
||||
#include "ns3/opengym-module.h"
|
||||
#include "tcp-rl.h"
|
||||
|
||||
using namespace ns3;
|
||||
|
||||
NS_LOG_COMPONENT_DEFINE ("TcpVariantsComparison");
|
||||
|
||||
static std::vector<uint32_t> rxPkts;
|
||||
|
||||
static void
|
||||
CountRxPkts(uint32_t sinkId, Ptr<const Packet> packet, const Address & srcAddr)
|
||||
{
|
||||
rxPkts[sinkId]++;
|
||||
}
|
||||
|
||||
static void
|
||||
PrintRxCount()
|
||||
{
|
||||
uint32_t size = rxPkts.size();
|
||||
NS_LOG_UNCOND("RxPkts:");
|
||||
for (uint32_t i=0; i<size; i++){
|
||||
NS_LOG_UNCOND("---SinkId: "<< i << " RxPkts: " << rxPkts.at(i));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int main (int argc, char *argv[])
|
||||
{
|
||||
uint32_t openGymPort = 5555;
|
||||
double tcpEnvTimeStep = 0.1;
|
||||
|
||||
uint32_t nLeaf = 1;
|
||||
std::string transport_prot = "TcpRl";
|
||||
|
||||
double error_p = 0.0;
|
||||
std::string bottleneck_bandwidth = "2Mbps";
|
||||
std::string bottleneck_delay = "0.01ms";
|
||||
std::string access_bandwidth = "10Mbps";
|
||||
std::string access_delay = "20ms";
|
||||
std::string prefix_file_name = "TcpVariantsComparison";
|
||||
uint64_t data_mbytes = 0;
|
||||
uint32_t mtu_bytes = 400;
|
||||
double duration = 10.0;
|
||||
uint32_t run = 0;
|
||||
bool flow_monitor = false;
|
||||
bool sack = true;
|
||||
std::string queue_disc_type = "ns3::PfifoFastQueueDisc";
|
||||
std::string recovery = "ns3::TcpClassicRecovery";
|
||||
|
||||
CommandLine cmd;
|
||||
// required parameters for OpenGym interface
|
||||
cmd.AddValue ("openGymPort", "Port number for OpenGym env. Default: 5555", openGymPort);
|
||||
cmd.AddValue ("simSeed", "Seed for random generator. Default: 1", run);
|
||||
cmd.AddValue ("envTimeStep", "Time step interval for time-based TCP env [s]. Default: 0.1s", tcpEnvTimeStep);
|
||||
// other parameters
|
||||
cmd.AddValue ("nLeaf", "Number of left and right side leaf nodes", nLeaf);
|
||||
cmd.AddValue ("transport_prot", "Transport protocol to use: TcpNewReno, "
|
||||
"TcpHybla, TcpHighSpeed, TcpHtcp, TcpVegas, TcpScalable, TcpVeno, "
|
||||
"TcpBic, TcpYeah, TcpIllinois, TcpWestwood, TcpWestwoodPlus, TcpLedbat, "
|
||||
"TcpLp, TcpRl, TcpRlTimeBased", transport_prot);
|
||||
cmd.AddValue ("error_p", "Packet error rate", error_p);
|
||||
cmd.AddValue ("bottleneck_bandwidth", "Bottleneck bandwidth", bottleneck_bandwidth);
|
||||
cmd.AddValue ("bottleneck_delay", "Bottleneck delay", bottleneck_delay);
|
||||
cmd.AddValue ("access_bandwidth", "Access link bandwidth", access_bandwidth);
|
||||
cmd.AddValue ("access_delay", "Access link delay", access_delay);
|
||||
cmd.AddValue ("prefix_name", "Prefix of output trace file", prefix_file_name);
|
||||
cmd.AddValue ("data", "Number of Megabytes of data to transmit", data_mbytes);
|
||||
cmd.AddValue ("mtu", "Size of IP packets to send in bytes", mtu_bytes);
|
||||
cmd.AddValue ("duration", "Time to allow flows to run in seconds", duration);
|
||||
cmd.AddValue ("run", "Run index (for setting repeatable seeds)", run);
|
||||
cmd.AddValue ("flow_monitor", "Enable flow monitor", flow_monitor);
|
||||
cmd.AddValue ("queue_disc_type", "Queue disc type for gateway (e.g. ns3::CoDelQueueDisc)", queue_disc_type);
|
||||
cmd.AddValue ("sack", "Enable or disable SACK option", sack);
|
||||
cmd.AddValue ("recovery", "Recovery algorithm type to use (e.g., ns3::TcpPrrRecovery", recovery);
|
||||
cmd.Parse (argc, argv);
|
||||
|
||||
transport_prot = std::string ("ns3::") + transport_prot;
|
||||
|
||||
SeedManager::SetSeed (1);
|
||||
SeedManager::SetRun (run);
|
||||
|
||||
NS_LOG_UNCOND("Ns3Env parameters:");
|
||||
if (transport_prot.compare ("ns3::TcpRl") == 0 or transport_prot.compare ("ns3::TcpRlTimeBased") == 0)
|
||||
{
|
||||
NS_LOG_UNCOND("--openGymPort: " << openGymPort);
|
||||
} else {
|
||||
NS_LOG_UNCOND("--openGymPort: No OpenGym");
|
||||
}
|
||||
|
||||
NS_LOG_UNCOND("--seed: " << run);
|
||||
NS_LOG_UNCOND("--Tcp version: " << transport_prot);
|
||||
|
||||
|
||||
|
||||
// OpenGym Env --- has to be created before any other thing
|
||||
Ptr<OpenGymInterface> openGymInterface;
|
||||
if (transport_prot.compare ("ns3::TcpRl") == 0)
|
||||
{
|
||||
openGymInterface = OpenGymInterface::Get(openGymPort);
|
||||
Config::SetDefault ("ns3::TcpRl::Reward", DoubleValue (2.0)); // Reward when increasing congestion window
|
||||
Config::SetDefault ("ns3::TcpRl::Penalty", DoubleValue (-30.0)); // Penalty when decreasing congestion window
|
||||
}
|
||||
|
||||
if (transport_prot.compare ("ns3::TcpRlTimeBased") == 0)
|
||||
{
|
||||
openGymInterface = OpenGymInterface::Get(openGymPort);
|
||||
Config::SetDefault ("ns3::TcpRlTimeBased::StepTime", TimeValue (Seconds(tcpEnvTimeStep))); // Time step of env
|
||||
Config::SetDefault ("ns3::TcpRlTimeBased::Duration", TimeValue (Seconds(duration))); // Duration of env sim
|
||||
Config::SetDefault ("ns3::TcpRlTimeBased::Reward", DoubleValue (1.0)); // Reward
|
||||
Config::SetDefault ("ns3::TcpRlTimeBased::Penalty", DoubleValue (-1.0)); // Penalty
|
||||
}
|
||||
|
||||
// Calculate the ADU size
|
||||
Header* temp_header = new Ipv4Header ();
|
||||
uint32_t ip_header = temp_header->GetSerializedSize ();
|
||||
NS_LOG_LOGIC ("IP Header size is: " << ip_header);
|
||||
delete temp_header;
|
||||
temp_header = new TcpHeader ();
|
||||
uint32_t tcp_header = temp_header->GetSerializedSize ();
|
||||
NS_LOG_LOGIC ("TCP Header size is: " << tcp_header);
|
||||
delete temp_header;
|
||||
uint32_t tcp_adu_size = mtu_bytes - 20 - (ip_header + tcp_header);
|
||||
NS_LOG_LOGIC ("TCP ADU size is: " << tcp_adu_size);
|
||||
|
||||
// Set the simulation start and stop time
|
||||
double start_time = 0.1; // it takes some time to initialise some variables, idk why
|
||||
double stop_time = start_time + duration;
|
||||
|
||||
// 4 MB of TCP buffer
|
||||
Config::SetDefault ("ns3::TcpSocket::RcvBufSize", UintegerValue (1 << 21));
|
||||
Config::SetDefault ("ns3::TcpSocket::SndBufSize", UintegerValue (1 << 21));
|
||||
Config::SetDefault ("ns3::TcpSocketBase::Sack", BooleanValue (sack));
|
||||
// no. of packets received before an ACK is sent (why is the default 2?)
|
||||
Config::SetDefault ("ns3::TcpSocket::DelAckCount", UintegerValue (2));
|
||||
|
||||
|
||||
Config::SetDefault ("ns3::TcpL4Protocol::RecoveryType",
|
||||
TypeIdValue (TypeId::LookupByName (recovery)));
|
||||
// Select TCP variant
|
||||
if (transport_prot.compare ("ns3::TcpWestwoodPlus") == 0)
|
||||
{
|
||||
// TcpWestwoodPlus is not an actual TypeId name; we need TcpWestwood here
|
||||
Config::SetDefault ("ns3::TcpL4Protocol::SocketType", TypeIdValue (TcpWestwood::GetTypeId ()));
|
||||
// the default protocol type in ns3::TcpWestwood is WESTWOOD
|
||||
Config::SetDefault ("ns3::TcpWestwood::ProtocolType", EnumValue (TcpWestwood::WESTWOODPLUS));
|
||||
}
|
||||
else
|
||||
{
|
||||
TypeId tcpTid;
|
||||
NS_ABORT_MSG_UNLESS (TypeId::LookupByNameFailSafe (transport_prot, &tcpTid), "TypeId " << transport_prot << " not found");
|
||||
Config::SetDefault ("ns3::TcpL4Protocol::SocketType", TypeIdValue (TypeId::LookupByName (transport_prot)));
|
||||
}
|
||||
|
||||
// Configure the error model
|
||||
// Here we use RateErrorModel with packet error rate
|
||||
Ptr<UniformRandomVariable> uv = CreateObject<UniformRandomVariable> ();
|
||||
uv->SetStream (50);
|
||||
RateErrorModel error_model;
|
||||
error_model.SetRandomVariable (uv);
|
||||
error_model.SetUnit (RateErrorModel::ERROR_UNIT_PACKET);
|
||||
error_model.SetRate (error_p);
|
||||
|
||||
// Create the point-to-point link helpers
|
||||
PointToPointHelper bottleNeckLink;
|
||||
bottleNeckLink.SetDeviceAttribute ("DataRate", StringValue (bottleneck_bandwidth));
|
||||
bottleNeckLink.SetChannelAttribute ("Delay", StringValue (bottleneck_delay));
|
||||
//bottleNeckLink.SetDeviceAttribute ("ReceiveErrorModel", PointerValue (&error_model));
|
||||
|
||||
PointToPointHelper pointToPointLeaf;
|
||||
pointToPointLeaf.SetDeviceAttribute ("DataRate", StringValue (access_bandwidth));
|
||||
pointToPointLeaf.SetChannelAttribute ("Delay", StringValue (access_delay));
|
||||
|
||||
PointToPointDumbbellHelper d (nLeaf, pointToPointLeaf,
|
||||
nLeaf, pointToPointLeaf,
|
||||
bottleNeckLink);
|
||||
|
||||
// Install IP stack
|
||||
InternetStackHelper stack;
|
||||
stack.InstallAll ();
|
||||
|
||||
// Traffic Control
|
||||
TrafficControlHelper tchPfifo;
|
||||
tchPfifo.SetRootQueueDisc ("ns3::PfifoFastQueueDisc");
|
||||
|
||||
TrafficControlHelper tchCoDel;
|
||||
tchCoDel.SetRootQueueDisc ("ns3::CoDelQueueDisc");
|
||||
|
||||
DataRate access_b (access_bandwidth);
|
||||
DataRate bottle_b (bottleneck_bandwidth);
|
||||
Time access_d (access_delay);
|
||||
Time bottle_d (bottleneck_delay);
|
||||
|
||||
uint32_t size = static_cast<uint32_t>((std::min (access_b, bottle_b).GetBitRate () / 8) *
|
||||
((access_d + bottle_d + access_d) * 2).GetSeconds ());
|
||||
|
||||
Config::SetDefault ("ns3::PfifoFastQueueDisc::MaxSize",
|
||||
QueueSizeValue (QueueSize (QueueSizeUnit::PACKETS, size / mtu_bytes)));
|
||||
Config::SetDefault ("ns3::CoDelQueueDisc::MaxSize",
|
||||
QueueSizeValue (QueueSize (QueueSizeUnit::BYTES, size)));
|
||||
|
||||
if (queue_disc_type.compare ("ns3::PfifoFastQueueDisc") == 0)
|
||||
{
|
||||
tchPfifo.Install (d.GetLeft()->GetDevice(1));
|
||||
tchPfifo.Install (d.GetRight()->GetDevice(1));
|
||||
}
|
||||
else if (queue_disc_type.compare ("ns3::CoDelQueueDisc") == 0)
|
||||
{
|
||||
tchCoDel.Install (d.GetLeft()->GetDevice(1));
|
||||
tchCoDel.Install (d.GetRight()->GetDevice(1));
|
||||
}
|
||||
else
|
||||
{
|
||||
NS_FATAL_ERROR ("Queue not recognized. Allowed values are ns3::CoDelQueueDisc or ns3::PfifoFastQueueDisc");
|
||||
}
|
||||
|
||||
// Assign IP Addresses
|
||||
d.AssignIpv4Addresses (Ipv4AddressHelper ("10.1.1.0", "255.255.255.0"),
|
||||
Ipv4AddressHelper ("10.2.1.0", "255.255.255.0"),
|
||||
Ipv4AddressHelper ("10.3.1.0", "255.255.255.0"));
|
||||
|
||||
|
||||
NS_LOG_INFO ("Initialize Global Routing.");
|
||||
Ipv4GlobalRoutingHelper::PopulateRoutingTables ();
|
||||
|
||||
// Install apps in left and right nodes
|
||||
uint16_t port = 50000;
|
||||
Address sinkLocalAddress (InetSocketAddress (Ipv4Address::GetAny (), port));
|
||||
PacketSinkHelper sinkHelper ("ns3::TcpSocketFactory", sinkLocalAddress);
|
||||
ApplicationContainer sinkApps;
|
||||
for (uint32_t i = 0; i < d.RightCount (); ++i)
|
||||
{
|
||||
sinkHelper.SetAttribute ("Protocol", TypeIdValue (TcpSocketFactory::GetTypeId ()));
|
||||
sinkApps.Add (sinkHelper.Install (d.GetRight (i)));
|
||||
}
|
||||
sinkApps.Start (Seconds (0.0));
|
||||
sinkApps.Stop (Seconds (stop_time));
|
||||
|
||||
for (uint32_t i = 0; i < d.LeftCount (); ++i)
|
||||
{
|
||||
// Create an on/off app sending packets to the left side
|
||||
AddressValue remoteAddress (InetSocketAddress (d.GetRightIpv4Address (i), port));
|
||||
Config::SetDefault ("ns3::TcpSocket::SegmentSize", UintegerValue (tcp_adu_size));
|
||||
BulkSendHelper ftp ("ns3::TcpSocketFactory", Address ());
|
||||
ftp.SetAttribute ("Remote", remoteAddress);
|
||||
ftp.SetAttribute ("SendSize", UintegerValue (tcp_adu_size));
|
||||
ftp.SetAttribute ("MaxBytes", UintegerValue (data_mbytes * 1000000));
|
||||
|
||||
ApplicationContainer clientApp = ftp.Install (d.GetLeft (i));
|
||||
clientApp.Start (Seconds (start_time * i)); // Start after sink
|
||||
clientApp.Stop (Seconds (stop_time - 3)); // Stop before the sink
|
||||
}
|
||||
|
||||
// Flow monitor
|
||||
FlowMonitorHelper flowHelper;
|
||||
if (flow_monitor)
|
||||
{
|
||||
flowHelper.InstallAll ();
|
||||
}
|
||||
|
||||
// Count RX packets
|
||||
for (uint32_t i = 0; i < d.RightCount (); ++i)
|
||||
{
|
||||
rxPkts.push_back(0);
|
||||
Ptr<PacketSink> pktSink = DynamicCast<PacketSink>(sinkApps.Get(i));
|
||||
pktSink->TraceConnectWithoutContext ("Rx", MakeBoundCallback (&CountRxPkts, i));
|
||||
}
|
||||
|
||||
Simulator::Stop (Seconds (stop_time));
|
||||
Simulator::Run ();
|
||||
|
||||
if (flow_monitor)
|
||||
{
|
||||
flowHelper.SerializeToXmlFile (prefix_file_name + ".flowmonitor", true, true);
|
||||
}
|
||||
|
||||
if (transport_prot.compare ("ns3::TcpRl") == 0 or transport_prot.compare ("ns3::TcpRlTimeBased") == 0)
|
||||
{
|
||||
openGymInterface->NotifySimulationEnd();
|
||||
}
|
||||
|
||||
PrintRxCount();
|
||||
Simulator::Destroy ();
|
||||
return 0;
|
||||
}
|
||||
714
tcp-rl-env.cc
Normal file
714
tcp-rl-env.cc
Normal file
@ -0,0 +1,714 @@
|
||||
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
|
||||
/*
|
||||
* Copyright (c) 2018 Technische Universität Berlin
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License version 2 as
|
||||
* published by the Free Software Foundation;
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program; if not, write to the Free Software
|
||||
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
|
||||
*
|
||||
* Author: Piotr Gawlowicz <gawlowicz@tkn.tu-berlin.de>
|
||||
*/
|
||||
|
||||
#include "tcp-rl-env.h"
|
||||
#include "ns3/tcp-header.h"
|
||||
#include "ns3/object.h"
|
||||
#include "ns3/core-module.h"
|
||||
#include "ns3/log.h"
|
||||
#include "ns3/simulator.h"
|
||||
#include "ns3/tcp-socket-base.h"
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
|
||||
|
||||
namespace ns3 {
|
||||
|
||||
NS_LOG_COMPONENT_DEFINE ("ns3::TcpGymEnv");
|
||||
NS_OBJECT_ENSURE_REGISTERED (TcpGymEnv);
|
||||
|
||||
TcpGymEnv::TcpGymEnv ()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
SetOpenGymInterface(OpenGymInterface::Get());
|
||||
}
|
||||
|
||||
TcpGymEnv::~TcpGymEnv ()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
TypeId
|
||||
TcpGymEnv::GetTypeId (void)
|
||||
{
|
||||
static TypeId tid = TypeId ("ns3::TcpGymEnv")
|
||||
.SetParent<OpenGymEnv> ()
|
||||
.SetGroupName ("OpenGym")
|
||||
;
|
||||
|
||||
return tid;
|
||||
}
|
||||
|
||||
void
|
||||
TcpGymEnv::DoDispose ()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
void
|
||||
TcpGymEnv::SetNodeId(uint32_t id)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
m_nodeId = id;
|
||||
}
|
||||
|
||||
void
|
||||
TcpGymEnv::SetSocketUuid(uint32_t id)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
m_socketUuid = id;
|
||||
}
|
||||
|
||||
std::string
|
||||
TcpGymEnv::GetTcpCongStateName(const TcpSocketState::TcpCongState_t state)
|
||||
{
|
||||
std::string stateName = "UNKNOWN";
|
||||
switch(state) {
|
||||
case TcpSocketState::CA_OPEN:
|
||||
stateName = "CA_OPEN";
|
||||
break;
|
||||
case TcpSocketState::CA_DISORDER:
|
||||
stateName = "CA_DISORDER";
|
||||
break;
|
||||
case TcpSocketState::CA_CWR:
|
||||
stateName = "CA_CWR";
|
||||
break;
|
||||
case TcpSocketState::CA_RECOVERY:
|
||||
stateName = "CA_RECOVERY";
|
||||
break;
|
||||
case TcpSocketState::CA_LOSS:
|
||||
stateName = "CA_LOSS";
|
||||
break;
|
||||
case TcpSocketState::CA_LAST_STATE:
|
||||
stateName = "CA_LAST_STATE";
|
||||
break;
|
||||
default:
|
||||
stateName = "UNKNOWN";
|
||||
break;
|
||||
}
|
||||
return stateName;
|
||||
}
|
||||
|
||||
std::string
|
||||
TcpGymEnv::GetTcpCAEventName(const TcpSocketState::TcpCAEvent_t event)
|
||||
{
|
||||
std::string eventName = "UNKNOWN";
|
||||
switch(event) {
|
||||
case TcpSocketState::CA_EVENT_TX_START:
|
||||
eventName = "CA_EVENT_TX_START";
|
||||
break;
|
||||
case TcpSocketState::CA_EVENT_CWND_RESTART:
|
||||
eventName = "CA_EVENT_CWND_RESTART";
|
||||
break;
|
||||
case TcpSocketState::CA_EVENT_COMPLETE_CWR:
|
||||
eventName = "CA_EVENT_COMPLETE_CWR";
|
||||
break;
|
||||
case TcpSocketState::CA_EVENT_LOSS:
|
||||
eventName = "CA_EVENT_LOSS";
|
||||
break;
|
||||
case TcpSocketState::CA_EVENT_ECN_NO_CE:
|
||||
eventName = "CA_EVENT_ECN_NO_CE";
|
||||
break;
|
||||
case TcpSocketState::CA_EVENT_ECN_IS_CE:
|
||||
eventName = "CA_EVENT_ECN_IS_CE";
|
||||
break;
|
||||
case TcpSocketState::CA_EVENT_DELAYED_ACK:
|
||||
eventName = "CA_EVENT_DELAYED_ACK";
|
||||
break;
|
||||
case TcpSocketState::CA_EVENT_NON_DELAYED_ACK:
|
||||
eventName = "CA_EVENT_NON_DELAYED_ACK";
|
||||
break;
|
||||
default:
|
||||
eventName = "UNKNOWN";
|
||||
break;
|
||||
}
|
||||
return eventName;
|
||||
}
|
||||
|
||||
/*
|
||||
Define action space
|
||||
*/
|
||||
Ptr<OpenGymSpace>
|
||||
TcpGymEnv::GetActionSpace()
|
||||
{
|
||||
// new_ssThresh
|
||||
// new_cWnd
|
||||
uint32_t parameterNum = 2;
|
||||
float low = 0.0;
|
||||
float high = 65535;
|
||||
std::vector<uint32_t> shape = {parameterNum,};
|
||||
std::string dtype = TypeNameGet<uint32_t> ();
|
||||
|
||||
Ptr<OpenGymBoxSpace> box = CreateObject<OpenGymBoxSpace> (low, high, shape, dtype);
|
||||
NS_LOG_INFO ("MyGetActionSpace: " << box);
|
||||
return box;
|
||||
}
|
||||
|
||||
/*
|
||||
Define game over condition
|
||||
*/
|
||||
bool
|
||||
TcpGymEnv::GetGameOver()
|
||||
{
|
||||
m_isGameOver = false;
|
||||
bool test = false;
|
||||
static float stepCounter = 0.0;
|
||||
stepCounter += 1;
|
||||
if (stepCounter == 10 && test) {
|
||||
m_isGameOver = true;
|
||||
}
|
||||
NS_LOG_INFO ("MyGetGameOver: " << m_isGameOver);
|
||||
return m_isGameOver;
|
||||
}
|
||||
|
||||
/*
|
||||
Define reward function
|
||||
*/
|
||||
float
|
||||
TcpGymEnv::GetReward()
|
||||
{
|
||||
NS_LOG_INFO("MyGetReward: " << m_envReward);
|
||||
return m_envReward;
|
||||
}
|
||||
|
||||
/*
|
||||
Define extra info. Optional
|
||||
*/
|
||||
std::string
|
||||
TcpGymEnv::GetExtraInfo()
|
||||
{
|
||||
NS_LOG_INFO("MyGetExtraInfo: " << m_info);
|
||||
return m_info;
|
||||
}
|
||||
|
||||
/*
|
||||
Execute received actions
|
||||
*/
|
||||
bool
|
||||
TcpGymEnv::ExecuteActions(Ptr<OpenGymDataContainer> action)
|
||||
{
|
||||
Ptr<OpenGymBoxContainer<uint32_t> > box = DynamicCast<OpenGymBoxContainer<uint32_t> >(action);
|
||||
m_new_ssThresh = box->GetValue(0);
|
||||
m_new_cWnd = box->GetValue(1);
|
||||
|
||||
NS_LOG_INFO ("MyExecuteActions: " << action);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
NS_OBJECT_ENSURE_REGISTERED (TcpEventGymEnv);
|
||||
|
||||
TcpEventGymEnv::TcpEventGymEnv () : TcpGymEnv()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
TcpEventGymEnv::~TcpEventGymEnv ()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
TypeId
|
||||
TcpEventGymEnv::GetTypeId (void)
|
||||
{
|
||||
static TypeId tid = TypeId ("ns3::TcpEventGymEnv")
|
||||
.SetParent<TcpGymEnv> ()
|
||||
.SetGroupName ("OpenGym")
|
||||
.AddConstructor<TcpEventGymEnv> ()
|
||||
;
|
||||
|
||||
return tid;
|
||||
}
|
||||
|
||||
void
|
||||
TcpEventGymEnv::DoDispose ()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
void
|
||||
TcpEventGymEnv::SetReward(float value)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
m_reward = value;
|
||||
}
|
||||
|
||||
void
|
||||
TcpEventGymEnv::SetPenalty(float value)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
m_penalty = value;
|
||||
}
|
||||
|
||||
/*
|
||||
Define observation space
|
||||
*/
|
||||
Ptr<OpenGymSpace>
|
||||
TcpEventGymEnv::GetObservationSpace()
|
||||
{
|
||||
// socket unique ID
|
||||
// tcp env type: event-based = 0 / time-based = 1
|
||||
// sim time in us
|
||||
// node ID
|
||||
// ssThresh
|
||||
// cWnd
|
||||
// segmentSize
|
||||
// segmentsAcked
|
||||
// bytesInFlight
|
||||
// rtt in us
|
||||
// min rtt in us
|
||||
// called func
|
||||
// congestion algorithm (CA) state
|
||||
// CA event
|
||||
// ECN state
|
||||
uint32_t parameterNum = 10;
|
||||
float low = 0.0;
|
||||
float high = 1000000000.0;
|
||||
std::vector<uint32_t> shape = {parameterNum,};
|
||||
std::string dtype = TypeNameGet<uint64_t> ();
|
||||
|
||||
Ptr<OpenGymBoxSpace> box = CreateObject<OpenGymBoxSpace> (low, high, shape, dtype);
|
||||
NS_LOG_INFO ("MyGetObservationSpace: " << box);
|
||||
return box;
|
||||
}
|
||||
|
||||
/*
|
||||
Collect observations
|
||||
*/
|
||||
Ptr<OpenGymDataContainer>
|
||||
TcpEventGymEnv::GetObservation()
|
||||
{
|
||||
uint32_t parameterNum = 10;
|
||||
std::vector<uint32_t> shape = {parameterNum,};
|
||||
|
||||
Ptr<OpenGymBoxContainer<uint64_t> > box = CreateObject<OpenGymBoxContainer<uint64_t> >(shape);
|
||||
|
||||
box->AddValue(m_socketUuid);
|
||||
box->AddValue(0);
|
||||
box->AddValue(Simulator::Now().GetMicroSeconds ());
|
||||
box->AddValue(m_nodeId);
|
||||
box->AddValue(m_tcb->m_ssThresh);
|
||||
box->AddValue(m_tcb->m_cWnd);
|
||||
box->AddValue(m_tcb->m_segmentSize);
|
||||
box->AddValue(m_segmentsAcked);
|
||||
box->AddValue(m_bytesInFlight);
|
||||
box->AddValue(m_rtt.GetMicroSeconds ());
|
||||
//box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ());
|
||||
//box->AddValue(m_calledFunc);
|
||||
//box->AddValue(m_tcb->m_congState);
|
||||
//box->AddValue(m_event);
|
||||
//box->AddValue(m_tcb->m_ecnState);
|
||||
|
||||
// Print data
|
||||
NS_LOG_INFO ("MyGetObservation: " << box);
|
||||
return box;
|
||||
}
|
||||
|
||||
void
|
||||
TcpEventGymEnv::TxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
void
|
||||
TcpEventGymEnv::RxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
uint32_t
|
||||
TcpEventGymEnv::GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
// pkt was lost, so penalty
|
||||
m_envReward = m_penalty;
|
||||
|
||||
NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " GetSsThresh, BytesInFlight: " << bytesInFlight);
|
||||
m_calledFunc = CalledFunc_t::GET_SS_THRESH;
|
||||
m_info = "GetSsThresh";
|
||||
m_tcb = tcb;
|
||||
m_bytesInFlight = bytesInFlight;
|
||||
Notify();
|
||||
return m_new_ssThresh;
|
||||
}
|
||||
|
||||
void
|
||||
TcpEventGymEnv::IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
// pkt was acked, so reward
|
||||
m_envReward = m_reward;
|
||||
|
||||
NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " IncreaseWindow, SegmentsAcked: " << segmentsAcked);
|
||||
m_calledFunc = CalledFunc_t::INCREASE_WINDOW;
|
||||
m_info = "IncreaseWindow";
|
||||
m_tcb = tcb;
|
||||
m_segmentsAcked = segmentsAcked;
|
||||
Notify();
|
||||
tcb->m_cWnd = m_new_cWnd;
|
||||
}
|
||||
|
||||
void
|
||||
TcpEventGymEnv::PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " PktsAcked, SegmentsAcked: " << segmentsAcked << " Rtt: " << rtt);
|
||||
m_calledFunc = CalledFunc_t::PKTS_ACKED;
|
||||
m_info = "PktsAcked";
|
||||
m_tcb = tcb;
|
||||
m_segmentsAcked = segmentsAcked;
|
||||
m_rtt = rtt;
|
||||
}
|
||||
|
||||
void
|
||||
TcpEventGymEnv::CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
std::string stateName = GetTcpCongStateName(newState);
|
||||
NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CongestionStateSet: " << newState << " " << stateName);
|
||||
|
||||
m_calledFunc = CalledFunc_t::CONGESTION_STATE_SET;
|
||||
m_info = "CongestionStateSet";
|
||||
m_tcb = tcb;
|
||||
m_newState = newState;
|
||||
}
|
||||
|
||||
void
|
||||
TcpEventGymEnv::CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
std::string eventName = GetTcpCAEventName(event);
|
||||
NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CwndEvent: " << event << " " << eventName);
|
||||
|
||||
m_calledFunc = CalledFunc_t::CWND_EVENT;
|
||||
m_info = "CwndEvent";
|
||||
m_tcb = tcb;
|
||||
m_event = event;
|
||||
}
|
||||
|
||||
|
||||
NS_OBJECT_ENSURE_REGISTERED (TcpTimeStepGymEnv);
|
||||
|
||||
TcpTimeStepGymEnv::TcpTimeStepGymEnv () : TcpGymEnv()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
void
|
||||
TcpTimeStepGymEnv::ScheduleNextStateRead ()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
Simulator::Schedule (m_timeStep, &TcpTimeStepGymEnv::ScheduleNextStateRead, this);
|
||||
Notify();
|
||||
}
|
||||
|
||||
TcpTimeStepGymEnv::~TcpTimeStepGymEnv ()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
TypeId
|
||||
TcpTimeStepGymEnv::GetTypeId (void)
|
||||
{
|
||||
static TypeId tid = TypeId ("ns3::TcpTimeStepGymEnv")
|
||||
.SetParent<TcpGymEnv> ()
|
||||
.SetGroupName ("OpenGym")
|
||||
.AddConstructor<TcpTimeStepGymEnv> ()
|
||||
;
|
||||
|
||||
return tid;
|
||||
}
|
||||
|
||||
void
|
||||
TcpTimeStepGymEnv::DoDispose ()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
void
|
||||
TcpTimeStepGymEnv::SetDuration(Time value)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
m_duration = value;
|
||||
}
|
||||
|
||||
void
|
||||
TcpTimeStepGymEnv::SetTimeStep(Time value)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
m_timeStep = value;
|
||||
}
|
||||
|
||||
void
|
||||
TcpTimeStepGymEnv::SetReward(float value)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
m_reward = value;
|
||||
}
|
||||
|
||||
void
|
||||
TcpTimeStepGymEnv::SetPenalty(float value)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
m_penalty = value;
|
||||
}
|
||||
|
||||
/*
|
||||
Define observation space
|
||||
*/
|
||||
Ptr<OpenGymSpace>
|
||||
TcpTimeStepGymEnv::GetObservationSpace()
|
||||
{
|
||||
// socket unique ID
|
||||
// tcp env type: event-based = 0 / time-based = 1
|
||||
// sim time in us
|
||||
// node ID
|
||||
// ssThresh
|
||||
// cWnd
|
||||
// segmentSize
|
||||
// bytesInFlightSum
|
||||
// bytesInFlightAvg
|
||||
// segmentsAckedSum
|
||||
// segmentsAckedAvg
|
||||
// avgRtt
|
||||
// minRtt
|
||||
// avgInterTx
|
||||
// avgInterRx
|
||||
// throughput
|
||||
uint32_t parameterNum = 16;
|
||||
float low = 0.0;
|
||||
float high = 1000000000.0;
|
||||
std::vector<uint32_t> shape = {parameterNum,};
|
||||
std::string dtype = TypeNameGet<uint64_t> ();
|
||||
|
||||
Ptr<OpenGymBoxSpace> box = CreateObject<OpenGymBoxSpace> (low, high, shape, dtype);
|
||||
NS_LOG_INFO ("MyGetObservationSpace: " << box);
|
||||
return box;
|
||||
}
|
||||
|
||||
/*
|
||||
Collect observations
|
||||
*/
|
||||
Ptr<OpenGymDataContainer>
|
||||
TcpTimeStepGymEnv::GetObservation()
|
||||
{
|
||||
uint32_t parameterNum = 16;
|
||||
std::vector<uint32_t> shape = {parameterNum,};
|
||||
|
||||
Ptr<OpenGymBoxContainer<uint64_t> > box = CreateObject<OpenGymBoxContainer<uint64_t> >(shape);
|
||||
|
||||
box->AddValue(m_socketUuid);
|
||||
box->AddValue(1);
|
||||
box->AddValue(Simulator::Now().GetMicroSeconds ());
|
||||
box->AddValue(m_nodeId);
|
||||
box->AddValue(m_tcb->m_ssThresh);
|
||||
box->AddValue(m_tcb->m_cWnd);
|
||||
box->AddValue(m_tcb->m_segmentSize);
|
||||
|
||||
//bytesInFlightSum
|
||||
uint64_t bytesInFlightSum = std::accumulate(m_bytesInFlight.begin(), m_bytesInFlight.end(), 0);
|
||||
box->AddValue(bytesInFlightSum);
|
||||
|
||||
//bytesInFlightAvg
|
||||
uint64_t bytesInFlightAvg = 0;
|
||||
if (m_bytesInFlight.size()) {
|
||||
bytesInFlightAvg = bytesInFlightSum / m_bytesInFlight.size();
|
||||
}
|
||||
box->AddValue(bytesInFlightAvg);
|
||||
|
||||
//segmentsAckedSum
|
||||
uint64_t segmentsAckedSum = std::accumulate(m_segmentsAcked.begin(), m_segmentsAcked.end(), 0);
|
||||
box->AddValue(segmentsAckedSum);
|
||||
|
||||
//segmentsAckedAvg
|
||||
uint64_t segmentsAckedAvg = 0;
|
||||
if (m_segmentsAcked.size()) {
|
||||
segmentsAckedAvg = segmentsAckedSum / m_segmentsAcked.size();
|
||||
}
|
||||
box->AddValue(segmentsAckedAvg);
|
||||
|
||||
//avgRtt
|
||||
Time avgRtt = Seconds(0.0);
|
||||
if(m_rttSampleNum) {
|
||||
avgRtt = m_rttSum / m_rttSampleNum;
|
||||
}
|
||||
box->AddValue(avgRtt.GetMicroSeconds ());
|
||||
|
||||
/*---------------------------------------------------------------------------------------------------*/
|
||||
/*---------------------------------------------------------------------------------------------------*/
|
||||
/*---------------------------------------------------------------------------------------------------*/
|
||||
|
||||
// Update reward based on overall average of avgRtt over all steps so far
|
||||
// only when agent increases cWnd
|
||||
// TODO: this is not the right way of doing this.
|
||||
// place this somewhere else. see TcpEventGymEnv, how they've done it.
|
||||
|
||||
if (m_new_cWnd > m_old_cWnd && m_totalAvgRttSum > 0 && avgRtt > 0) {
|
||||
// when agent increases cWnd
|
||||
if ((m_totalAvgRttSum / m_totalAvgRttNum) >= avgRtt) {
|
||||
// give reward for decreasing avgRtt
|
||||
m_envReward = m_reward;
|
||||
} else {
|
||||
// give penalty for increasing avgRtt
|
||||
m_envReward = m_penalty;
|
||||
}
|
||||
} else {
|
||||
// agent has not increased cWnd
|
||||
m_envReward = 0;
|
||||
}
|
||||
|
||||
// Update m_totalAvgRtSum and m_totalAvgRttNum
|
||||
m_totalAvgRttSum += avgRtt;
|
||||
m_totalAvgRttNum++;
|
||||
|
||||
m_old_cWnd = m_new_cWnd;
|
||||
/*---------------------------------------------------------------------------------------------------*/
|
||||
/*---------------------------------------------------------------------------------------------------*/
|
||||
/*---------------------------------------------------------------------------------------------------*/
|
||||
|
||||
//m_minRtt
|
||||
box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ());
|
||||
|
||||
//avgInterTx
|
||||
Time avgInterTx = Seconds(0.0);
|
||||
if (m_interTxTimeNum) {
|
||||
avgInterTx = m_interTxTimeSum / m_interTxTimeNum;
|
||||
}
|
||||
box->AddValue(avgInterTx.GetMicroSeconds ());
|
||||
|
||||
//avgInterRx
|
||||
Time avgInterRx = Seconds(0.0);
|
||||
if (m_interRxTimeNum) {
|
||||
avgInterRx = m_interRxTimeSum / m_interRxTimeNum;
|
||||
}
|
||||
box->AddValue(avgInterRx.GetMicroSeconds ());
|
||||
|
||||
//throughput bytes/s
|
||||
float throughput = (segmentsAckedSum * m_tcb->m_segmentSize) / m_timeStep.GetSeconds();
|
||||
box->AddValue(throughput);
|
||||
|
||||
// Print data
|
||||
NS_LOG_INFO ("MyGetObservation: " << box);
|
||||
|
||||
m_bytesInFlight.clear();
|
||||
m_segmentsAcked.clear();
|
||||
|
||||
m_rttSampleNum = 0;
|
||||
m_rttSum = MicroSeconds (0.0);
|
||||
|
||||
m_interTxTimeNum = 0;
|
||||
m_interTxTimeSum = MicroSeconds (0.0);
|
||||
|
||||
m_interRxTimeNum = 0;
|
||||
m_interRxTimeSum = MicroSeconds (0.0);
|
||||
|
||||
return box;
|
||||
}
|
||||
|
||||
void
|
||||
TcpTimeStepGymEnv::TxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
if ( m_lastPktTxTime > MicroSeconds(0.0) ) {
|
||||
Time interTxTime = Simulator::Now() - m_lastPktTxTime;
|
||||
m_interTxTimeSum += interTxTime;
|
||||
m_interTxTimeNum++;
|
||||
}
|
||||
|
||||
m_lastPktTxTime = Simulator::Now();
|
||||
}
|
||||
|
||||
void
|
||||
TcpTimeStepGymEnv::RxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
if ( m_lastPktRxTime > MicroSeconds(0.0) ) {
|
||||
Time interRxTime = Simulator::Now() - m_lastPktRxTime;
|
||||
m_interRxTimeSum += interRxTime;
|
||||
m_interRxTimeNum++;
|
||||
}
|
||||
|
||||
m_lastPktRxTime = Simulator::Now();
|
||||
}
|
||||
|
||||
uint32_t
|
||||
TcpTimeStepGymEnv::GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " GetSsThresh, BytesInFlight: " << bytesInFlight);
|
||||
m_tcb = tcb;
|
||||
m_bytesInFlight.push_back(bytesInFlight);
|
||||
|
||||
if (!m_started) {
|
||||
m_started = true;
|
||||
Notify();
|
||||
ScheduleNextStateRead();
|
||||
}
|
||||
|
||||
// action
|
||||
return m_new_ssThresh;
|
||||
}
|
||||
|
||||
void
|
||||
TcpTimeStepGymEnv::IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " IncreaseWindow, SegmentsAcked: " << segmentsAcked);
|
||||
m_tcb = tcb;
|
||||
m_segmentsAcked.push_back(segmentsAcked);
|
||||
m_bytesInFlight.push_back(tcb->m_bytesInFlight);
|
||||
|
||||
if (!m_started) {
|
||||
m_started = true;
|
||||
Notify();
|
||||
ScheduleNextStateRead();
|
||||
}
|
||||
// action
|
||||
tcb->m_cWnd = m_new_cWnd;
|
||||
}
|
||||
|
||||
void
|
||||
TcpTimeStepGymEnv::PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " PktsAcked, SegmentsAcked: " << segmentsAcked << " Rtt: " << rtt);
|
||||
m_tcb = tcb;
|
||||
m_rttSum += rtt;
|
||||
m_rttSampleNum++;
|
||||
}
|
||||
|
||||
void
|
||||
TcpTimeStepGymEnv::CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
std::string stateName = GetTcpCongStateName(newState);
|
||||
NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CongestionStateSet: " << newState << " " << stateName);
|
||||
m_tcb = tcb;
|
||||
}
|
||||
|
||||
void
|
||||
TcpTimeStepGymEnv::CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
std::string eventName = GetTcpCAEventName(event);
|
||||
NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CwndEvent: " << event << " " << eventName);
|
||||
}
|
||||
|
||||
} // namespace ns3
|
||||
210
tcp-rl-env.h
Normal file
210
tcp-rl-env.h
Normal file
@ -0,0 +1,210 @@
|
||||
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
|
||||
/*
|
||||
* Copyright (c) 2018 Technische Universität Berlin
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License version 2 as
|
||||
* published by the Free Software Foundation;
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program; if not, write to the Free Software
|
||||
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
|
||||
*
|
||||
* Author: Piotr Gawlowicz <gawlowicz@tkn.tu-berlin.de>
|
||||
*/
|
||||
|
||||
#ifndef TCP_RL_ENV_H
|
||||
#define TCP_RL_ENV_H
|
||||
|
||||
#include "ns3/opengym-module.h"
|
||||
#include "ns3/tcp-socket-base.h"
|
||||
#include <vector>
|
||||
|
||||
namespace ns3 {
|
||||
|
||||
class Packet;
|
||||
class TcpHeader;
|
||||
class TcpSocketBase;
|
||||
class Time;
|
||||
|
||||
|
||||
class TcpGymEnv : public OpenGymEnv
|
||||
{
|
||||
public:
|
||||
TcpGymEnv ();
|
||||
virtual ~TcpGymEnv ();
|
||||
static TypeId GetTypeId (void);
|
||||
virtual void DoDispose ();
|
||||
|
||||
void SetNodeId(uint32_t id);
|
||||
void SetSocketUuid(uint32_t id);
|
||||
|
||||
std::string GetTcpCongStateName(const TcpSocketState::TcpCongState_t state);
|
||||
std::string GetTcpCAEventName(const TcpSocketState::TcpCAEvent_t event);
|
||||
|
||||
// OpenGym interface
|
||||
virtual Ptr<OpenGymSpace> GetActionSpace();
|
||||
virtual bool GetGameOver();
|
||||
virtual float GetReward();
|
||||
virtual std::string GetExtraInfo();
|
||||
virtual bool ExecuteActions(Ptr<OpenGymDataContainer> action);
|
||||
|
||||
virtual Ptr<OpenGymSpace> GetObservationSpace() = 0;
|
||||
virtual Ptr<OpenGymDataContainer> GetObservation() = 0;
|
||||
|
||||
// trace packets, e.g. for calculating inter tx/rx time
|
||||
virtual void TxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>) = 0;
|
||||
virtual void RxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>) = 0;
|
||||
|
||||
// TCP congestion control interface
|
||||
virtual uint32_t GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight) = 0;
|
||||
virtual void IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked) = 0;
|
||||
// optional functions used to collect obs
|
||||
virtual void PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt) = 0;
|
||||
virtual void CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState) = 0;
|
||||
virtual void CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event) = 0;
|
||||
|
||||
typedef enum
|
||||
{
|
||||
GET_SS_THRESH = 0,
|
||||
INCREASE_WINDOW,
|
||||
PKTS_ACKED,
|
||||
CONGESTION_STATE_SET,
|
||||
CWND_EVENT,
|
||||
} CalledFunc_t;
|
||||
|
||||
protected:
|
||||
uint32_t m_nodeId;
|
||||
uint32_t m_socketUuid;
|
||||
|
||||
// state
|
||||
// obs has to be implemented in child class
|
||||
|
||||
// game over
|
||||
bool m_isGameOver;
|
||||
|
||||
// reward
|
||||
float m_envReward;
|
||||
|
||||
// extra info
|
||||
std::string m_info;
|
||||
|
||||
// actions
|
||||
uint32_t m_new_ssThresh;
|
||||
uint32_t m_new_cWnd;
|
||||
};
|
||||
|
||||
|
||||
class TcpEventGymEnv : public TcpGymEnv
|
||||
{
|
||||
public:
|
||||
TcpEventGymEnv ();
|
||||
virtual ~TcpEventGymEnv ();
|
||||
static TypeId GetTypeId (void);
|
||||
virtual void DoDispose ();
|
||||
|
||||
void SetReward(float value);
|
||||
void SetPenalty(float value);
|
||||
|
||||
// OpenGym interface
|
||||
virtual Ptr<OpenGymSpace> GetObservationSpace();
|
||||
Ptr<OpenGymDataContainer> GetObservation();
|
||||
|
||||
// trace packets, e.g. for calculating inter tx/rx time
|
||||
virtual void TxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>);
|
||||
virtual void RxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>);
|
||||
|
||||
// TCP congestion control interface
|
||||
virtual uint32_t GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight);
|
||||
virtual void IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked);
|
||||
// optional functions used to collect obs
|
||||
virtual void PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt);
|
||||
virtual void CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState);
|
||||
virtual void CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event);
|
||||
|
||||
private:
|
||||
// state
|
||||
CalledFunc_t m_calledFunc;
|
||||
Ptr<const TcpSocketState> m_tcb;
|
||||
uint32_t m_bytesInFlight;
|
||||
uint32_t m_segmentsAcked;
|
||||
Time m_rtt;
|
||||
TcpSocketState::TcpCongState_t m_newState;
|
||||
TcpSocketState::TcpCAEvent_t m_event;
|
||||
|
||||
// reward
|
||||
float m_reward;
|
||||
float m_penalty;
|
||||
};
|
||||
|
||||
|
||||
class TcpTimeStepGymEnv : public TcpGymEnv
|
||||
{
|
||||
public:
|
||||
TcpTimeStepGymEnv ();
|
||||
|
||||
virtual ~TcpTimeStepGymEnv ();
|
||||
static TypeId GetTypeId (void);
|
||||
virtual void DoDispose ();
|
||||
|
||||
void SetDuration(Time value);
|
||||
void SetTimeStep(Time value);
|
||||
void SetReward(float value);
|
||||
void SetPenalty(float value);
|
||||
|
||||
// OpenGym interface
|
||||
virtual Ptr<OpenGymSpace> GetObservationSpace();
|
||||
Ptr<OpenGymDataContainer> GetObservation();
|
||||
|
||||
// trace packets, e.g. for calculating inter tx/rx time
|
||||
virtual void TxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>);
|
||||
virtual void RxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>);
|
||||
|
||||
// TCP congestion control interface
|
||||
virtual uint32_t GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight);
|
||||
virtual void IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked);
|
||||
// optional functions used to collect obs
|
||||
virtual void PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt);
|
||||
virtual void CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState);
|
||||
virtual void CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event);
|
||||
|
||||
private:
|
||||
void ScheduleNextStateRead();
|
||||
bool m_started {false};
|
||||
Time m_duration;
|
||||
Time m_timeStep;
|
||||
|
||||
// state
|
||||
Ptr<const TcpSocketState> m_tcb;
|
||||
std::vector<uint32_t> m_bytesInFlight;
|
||||
std::vector<uint32_t> m_segmentsAcked;
|
||||
|
||||
uint64_t m_rttSampleNum {0};
|
||||
Time m_rttSum {MicroSeconds (0.0)};
|
||||
|
||||
Time m_lastPktTxTime {MicroSeconds(0.0)};
|
||||
Time m_lastPktRxTime {MicroSeconds(0.0)};
|
||||
uint64_t m_interTxTimeNum {0};
|
||||
Time m_interTxTimeSum {MicroSeconds (0.0)};
|
||||
uint64_t m_interRxTimeNum {0};
|
||||
Time m_interRxTimeSum {MicroSeconds (0.0)};
|
||||
Time m_prevAvgRtt {MicroSeconds (0.0)};
|
||||
Time m_totalAvgRttSum {MicroSeconds (0.0)};
|
||||
uint64_t m_totalAvgRttNum {0};
|
||||
uint32_t m_old_cWnd {0};
|
||||
|
||||
// reward
|
||||
float m_reward;
|
||||
float m_penalty;
|
||||
};
|
||||
|
||||
|
||||
|
||||
} // namespace ns3
|
||||
|
||||
#endif /* TCP_RL_ENV_H */
|
||||
382
tcp-rl.cc
Normal file
382
tcp-rl.cc
Normal file
@ -0,0 +1,382 @@
|
||||
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
|
||||
/*
|
||||
* Copyright (c) 2018 Technische Universität Berlin
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License version 2 as
|
||||
* published by the Free Software Foundation;
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program; if not, write to the Free Software
|
||||
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
|
||||
*
|
||||
* Author: Piotr Gawlowicz <gawlowicz@tkn.tu-berlin.de>
|
||||
*/
|
||||
|
||||
#include "tcp-rl.h"
|
||||
#include "tcp-rl-env.h"
|
||||
#include "ns3/tcp-header.h"
|
||||
#include "ns3/object.h"
|
||||
#include "ns3/node-list.h"
|
||||
#include "ns3/core-module.h"
|
||||
#include "ns3/log.h"
|
||||
#include "ns3/simulator.h"
|
||||
#include "ns3/tcp-socket-base.h"
|
||||
#include "ns3/tcp-l4-protocol.h"
|
||||
|
||||
|
||||
namespace ns3 {
|
||||
|
||||
|
||||
NS_OBJECT_ENSURE_REGISTERED (TcpSocketDerived);
|
||||
|
||||
TypeId
|
||||
TcpSocketDerived::GetTypeId (void)
|
||||
{
|
||||
static TypeId tid = TypeId ("ns3::TcpSocketDerived")
|
||||
.SetParent<TcpSocketBase> ()
|
||||
.SetGroupName ("Internet")
|
||||
.AddConstructor<TcpSocketDerived> ()
|
||||
;
|
||||
return tid;
|
||||
}
|
||||
|
||||
TypeId
|
||||
TcpSocketDerived::GetInstanceTypeId () const
|
||||
{
|
||||
return TcpSocketDerived::GetTypeId ();
|
||||
}
|
||||
|
||||
TcpSocketDerived::TcpSocketDerived (void)
|
||||
{
|
||||
}
|
||||
|
||||
Ptr<TcpCongestionOps>
|
||||
TcpSocketDerived::GetCongestionControlAlgorithm ()
|
||||
{
|
||||
return m_congestionControl;
|
||||
}
|
||||
|
||||
TcpSocketDerived::~TcpSocketDerived (void)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
NS_LOG_COMPONENT_DEFINE ("ns3::TcpRlBase");
|
||||
NS_OBJECT_ENSURE_REGISTERED (TcpRlBase);
|
||||
|
||||
TypeId
|
||||
TcpRlBase::GetTypeId (void)
|
||||
{
|
||||
static TypeId tid = TypeId ("ns3::TcpRlBase")
|
||||
.SetParent<TcpCongestionOps> ()
|
||||
.SetGroupName ("Internet")
|
||||
.AddConstructor<TcpRlBase> ()
|
||||
;
|
||||
return tid;
|
||||
}
|
||||
|
||||
TcpRlBase::TcpRlBase (void)
|
||||
: TcpCongestionOps ()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
m_tcpSocket = 0;
|
||||
m_tcpGymEnv = 0;
|
||||
}
|
||||
|
||||
TcpRlBase::TcpRlBase (const TcpRlBase& sock)
|
||||
: TcpCongestionOps (sock)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
m_tcpSocket = 0;
|
||||
m_tcpGymEnv = 0;
|
||||
}
|
||||
|
||||
TcpRlBase::~TcpRlBase (void)
|
||||
{
|
||||
m_tcpSocket = 0;
|
||||
m_tcpGymEnv = 0;
|
||||
}
|
||||
|
||||
uint64_t
|
||||
TcpRlBase::GenerateUuid ()
|
||||
{
|
||||
static uint64_t uuid = 0;
|
||||
uuid++;
|
||||
return uuid;
|
||||
}
|
||||
|
||||
void
|
||||
TcpRlBase::CreateGymEnv()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
// should never be called, only child classes: TcpRl and TcpRlTimeBased
|
||||
}
|
||||
|
||||
void
|
||||
TcpRlBase::ConnectSocketCallbacks()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
|
||||
bool foundSocket = false;
|
||||
for (NodeList::Iterator i = NodeList::Begin (); i != NodeList::End (); ++i) {
|
||||
Ptr<Node> node = *i;
|
||||
Ptr<TcpL4Protocol> tcp = node->GetObject<TcpL4Protocol> ();
|
||||
|
||||
ObjectVectorValue socketVec;
|
||||
tcp->GetAttribute ("SocketList", socketVec);
|
||||
NS_LOG_DEBUG("Node: " << node->GetId() << " TCP socket num: " << socketVec.GetN());
|
||||
|
||||
uint32_t sockNum = socketVec.GetN();
|
||||
for (uint32_t j=0; j<sockNum; j++) {
|
||||
Ptr<Object> sockObj = socketVec.Get(j);
|
||||
Ptr<TcpSocketBase> tcpSocket = DynamicCast<TcpSocketBase> (sockObj);
|
||||
NS_LOG_DEBUG("Node: " << node->GetId() << " TCP Socket: " << tcpSocket);
|
||||
if(!tcpSocket) { continue; }
|
||||
|
||||
Ptr<TcpSocketDerived> dtcpSocket = StaticCast<TcpSocketDerived>(tcpSocket);
|
||||
Ptr<TcpCongestionOps> ca = dtcpSocket->GetCongestionControlAlgorithm();
|
||||
NS_LOG_DEBUG("CA name: " << ca->GetName());
|
||||
Ptr<TcpRlBase> rlCa = DynamicCast<TcpRlBase>(ca);
|
||||
if (rlCa == this) {
|
||||
NS_LOG_DEBUG("Found TcpRl CA!");
|
||||
foundSocket = true;
|
||||
m_tcpSocket = tcpSocket;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (foundSocket) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
NS_ASSERT_MSG(m_tcpSocket, "TCP socket was not found.");
|
||||
|
||||
if(m_tcpSocket) {
|
||||
NS_LOG_DEBUG("Found TCP Socket: " << m_tcpSocket);
|
||||
m_tcpSocket->TraceConnectWithoutContext ("Tx", MakeCallback (&TcpGymEnv::TxPktTrace, m_tcpGymEnv));
|
||||
m_tcpSocket->TraceConnectWithoutContext ("Rx", MakeCallback (&TcpGymEnv::RxPktTrace, m_tcpGymEnv));
|
||||
NS_LOG_DEBUG("Connect socket callbacks " << m_tcpSocket->GetNode()->GetId());
|
||||
m_tcpGymEnv->SetNodeId(m_tcpSocket->GetNode()->GetId());
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
TcpRlBase::GetName () const
|
||||
{
|
||||
return "TcpRlBase";
|
||||
}
|
||||
|
||||
uint32_t
|
||||
TcpRlBase::GetSsThresh (Ptr<const TcpSocketState> state,
|
||||
uint32_t bytesInFlight)
|
||||
{
|
||||
NS_LOG_FUNCTION (this << state << bytesInFlight);
|
||||
|
||||
if (!m_tcpGymEnv) {
|
||||
CreateGymEnv();
|
||||
}
|
||||
|
||||
uint32_t newSsThresh = 0;
|
||||
if (m_tcpGymEnv) {
|
||||
newSsThresh = m_tcpGymEnv->GetSsThresh(state, bytesInFlight);
|
||||
}
|
||||
|
||||
return newSsThresh;
|
||||
}
|
||||
|
||||
void
|
||||
TcpRlBase::IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked)
|
||||
{
|
||||
NS_LOG_FUNCTION (this << tcb << segmentsAcked);
|
||||
|
||||
if (!m_tcpGymEnv) {
|
||||
CreateGymEnv();
|
||||
}
|
||||
|
||||
if (m_tcpGymEnv) {
|
||||
m_tcpGymEnv->IncreaseWindow(tcb, segmentsAcked);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
TcpRlBase::PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
|
||||
if (!m_tcpGymEnv) {
|
||||
CreateGymEnv();
|
||||
}
|
||||
|
||||
if (m_tcpGymEnv) {
|
||||
m_tcpGymEnv->PktsAcked(tcb, segmentsAcked, rtt);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
TcpRlBase::CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
|
||||
if (!m_tcpGymEnv) {
|
||||
CreateGymEnv();
|
||||
}
|
||||
|
||||
if (m_tcpGymEnv) {
|
||||
m_tcpGymEnv->CongestionStateSet(tcb, newState);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
TcpRlBase::CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
|
||||
if (!m_tcpGymEnv) {
|
||||
CreateGymEnv();
|
||||
}
|
||||
|
||||
if (m_tcpGymEnv) {
|
||||
m_tcpGymEnv->CwndEvent(tcb, event);
|
||||
}
|
||||
}
|
||||
|
||||
Ptr<TcpCongestionOps>
|
||||
TcpRlBase::Fork ()
|
||||
{
|
||||
return CopyObject<TcpRlBase> (this);
|
||||
}
|
||||
|
||||
|
||||
NS_OBJECT_ENSURE_REGISTERED (TcpRl);
|
||||
|
||||
TypeId
|
||||
TcpRl::GetTypeId (void)
|
||||
{
|
||||
static TypeId tid = TypeId ("ns3::TcpRl")
|
||||
.SetParent<TcpRlBase> ()
|
||||
.SetGroupName ("Internet")
|
||||
.AddConstructor<TcpRl> ()
|
||||
.AddAttribute ("Reward", "Reward when increasing congestion window.",
|
||||
DoubleValue (1.0),
|
||||
MakeDoubleAccessor (&TcpRl::m_reward),
|
||||
MakeDoubleChecker<double> ())
|
||||
.AddAttribute ("Penalty", "Reward when increasing congestion window.",
|
||||
DoubleValue (-10.0),
|
||||
MakeDoubleAccessor (&TcpRl::m_penalty),
|
||||
MakeDoubleChecker<double> ())
|
||||
;
|
||||
return tid;
|
||||
}
|
||||
|
||||
TcpRl::TcpRl (void)
|
||||
: TcpRlBase ()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
TcpRl::TcpRl (const TcpRl& sock)
|
||||
: TcpRlBase (sock)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
TcpRl::~TcpRl (void)
|
||||
{
|
||||
}
|
||||
|
||||
std::string
|
||||
TcpRl::GetName () const
|
||||
{
|
||||
return "TcpRl";
|
||||
}
|
||||
|
||||
void
|
||||
TcpRl::CreateGymEnv()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
Ptr<TcpEventGymEnv> env = CreateObject<TcpEventGymEnv>();
|
||||
env->SetSocketUuid(TcpRlBase::GenerateUuid());
|
||||
env->SetReward(m_reward);
|
||||
env->SetPenalty(m_penalty);
|
||||
m_tcpGymEnv = env;
|
||||
|
||||
ConnectSocketCallbacks();
|
||||
}
|
||||
|
||||
|
||||
NS_OBJECT_ENSURE_REGISTERED (TcpRlTimeBased);
|
||||
|
||||
TypeId
|
||||
TcpRlTimeBased::GetTypeId (void)
|
||||
{
|
||||
static TypeId tid = TypeId ("ns3::TcpRlTimeBased")
|
||||
.SetParent<TcpRlBase> ()
|
||||
.SetGroupName ("Internet")
|
||||
.AddConstructor<TcpRlTimeBased> ()
|
||||
.AddAttribute ("Duration",
|
||||
"Simulation Duration. Default: 10000ms",
|
||||
TimeValue (MilliSeconds (10000)),
|
||||
MakeTimeAccessor (&TcpRlTimeBased::m_duration),
|
||||
MakeTimeChecker ())
|
||||
.AddAttribute ("StepTime",
|
||||
"Step interval used in TCP env. Default: 100ms",
|
||||
TimeValue (MilliSeconds (100)),
|
||||
MakeTimeAccessor (&TcpRlTimeBased::m_timeStep),
|
||||
MakeTimeChecker ())
|
||||
.AddAttribute ("Reward", "Reward for increasing congestion window.",
|
||||
DoubleValue (1.0),
|
||||
MakeDoubleAccessor (&TcpRlTimeBased::m_reward),
|
||||
MakeDoubleChecker<double> ())
|
||||
.AddAttribute ("Penalty", "Penalty for increasing congestion window.",
|
||||
DoubleValue (-1.0),
|
||||
MakeDoubleAccessor (&TcpRlTimeBased::m_penalty),
|
||||
MakeDoubleChecker<double> ())
|
||||
;
|
||||
return tid;
|
||||
}
|
||||
|
||||
TcpRlTimeBased::TcpRlTimeBased (void)
|
||||
: TcpRlBase ()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
TcpRlTimeBased::TcpRlTimeBased (const TcpRlTimeBased& sock)
|
||||
: TcpRlBase (sock)
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
}
|
||||
|
||||
TcpRlTimeBased::~TcpRlTimeBased (void)
|
||||
{
|
||||
}
|
||||
|
||||
std::string
|
||||
TcpRlTimeBased::GetName () const
|
||||
{
|
||||
return "TcpRlTimeBased";
|
||||
}
|
||||
|
||||
void
|
||||
TcpRlTimeBased::CreateGymEnv()
|
||||
{
|
||||
NS_LOG_FUNCTION (this);
|
||||
Ptr<TcpTimeStepGymEnv> env = CreateObject<TcpTimeStepGymEnv> ();
|
||||
env->SetSocketUuid(TcpRlBase::GenerateUuid());
|
||||
env->SetDuration(m_duration);
|
||||
env->SetTimeStep(m_timeStep);
|
||||
env->SetReward(m_reward);
|
||||
env->SetPenalty(m_penalty);
|
||||
m_tcpGymEnv = env;
|
||||
|
||||
ConnectSocketCallbacks();
|
||||
}
|
||||
|
||||
} // namespace ns3
|
||||
127
tcp-rl.h
Normal file
127
tcp-rl.h
Normal file
@ -0,0 +1,127 @@
|
||||
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
|
||||
/*
|
||||
* Copyright (c) 2018 Technische Universität Berlin
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License version 2 as
|
||||
* published by the Free Software Foundation;
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program; if not, write to the Free Software
|
||||
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
|
||||
*
|
||||
* Author: Piotr Gawlowicz <gawlowicz@tkn.tu-berlin.de>
|
||||
*/
|
||||
|
||||
#ifndef TCP_RL_H
|
||||
#define TCP_RL_H
|
||||
|
||||
#include "ns3/tcp-congestion-ops.h"
|
||||
#include "ns3/opengym-module.h"
|
||||
#include "ns3/tcp-socket-base.h"
|
||||
|
||||
namespace ns3 {
|
||||
|
||||
class TcpSocketBase;
|
||||
class Time;
|
||||
class TcpGymEnv;
|
||||
|
||||
|
||||
// used to get pointer to Congestion Algorithm
|
||||
class TcpSocketDerived : public TcpSocketBase
|
||||
{
|
||||
public:
|
||||
static TypeId GetTypeId (void);
|
||||
virtual TypeId GetInstanceTypeId () const;
|
||||
|
||||
TcpSocketDerived (void);
|
||||
virtual ~TcpSocketDerived (void);
|
||||
|
||||
Ptr<TcpCongestionOps> GetCongestionControlAlgorithm ();
|
||||
};
|
||||
|
||||
|
||||
class TcpRlBase : public TcpCongestionOps
|
||||
{
|
||||
public:
|
||||
/**
|
||||
* \brief Get the type ID.
|
||||
* \return the object TypeId
|
||||
*/
|
||||
static TypeId GetTypeId (void);
|
||||
|
||||
TcpRlBase ();
|
||||
|
||||
/**
|
||||
* \brief Copy constructor.
|
||||
* \param sock object to copy.
|
||||
*/
|
||||
TcpRlBase (const TcpRlBase& sock);
|
||||
|
||||
~TcpRlBase ();
|
||||
|
||||
virtual std::string GetName () const;
|
||||
virtual uint32_t GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight);
|
||||
virtual void IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked);
|
||||
virtual void PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt);
|
||||
virtual void CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState);
|
||||
virtual void CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event);
|
||||
virtual Ptr<TcpCongestionOps> Fork ();
|
||||
|
||||
protected:
|
||||
static uint64_t GenerateUuid ();
|
||||
virtual void CreateGymEnv();
|
||||
void ConnectSocketCallbacks();
|
||||
|
||||
// OpenGymEnv interface
|
||||
Ptr<TcpSocketBase> m_tcpSocket;
|
||||
Ptr<TcpGymEnv> m_tcpGymEnv;
|
||||
};
|
||||
|
||||
|
||||
class TcpRl : public TcpRlBase
|
||||
{
|
||||
public:
|
||||
static TypeId GetTypeId (void);
|
||||
|
||||
TcpRl ();
|
||||
TcpRl (const TcpRl& sock);
|
||||
~TcpRl ();
|
||||
|
||||
virtual std::string GetName () const;
|
||||
private:
|
||||
virtual void CreateGymEnv();
|
||||
// OpenGymEnv env
|
||||
float m_reward {1.0};
|
||||
float m_penalty {-100.0};
|
||||
};
|
||||
|
||||
|
||||
class TcpRlTimeBased : public TcpRlBase
|
||||
{
|
||||
public:
|
||||
static TypeId GetTypeId (void);
|
||||
|
||||
TcpRlTimeBased ();
|
||||
TcpRlTimeBased (const TcpRlTimeBased& sock);
|
||||
~TcpRlTimeBased ();
|
||||
|
||||
virtual std::string GetName () const;
|
||||
|
||||
private:
|
||||
virtual void CreateGymEnv();
|
||||
|
||||
Time m_duration;
|
||||
Time m_timeStep;
|
||||
float m_reward;
|
||||
float m_penalty;
|
||||
};
|
||||
|
||||
} // namespace ns3
|
||||
|
||||
#endif /* TCP_RL_H */
|
||||
138
tcp_base.py
Normal file
138
tcp_base.py
Normal file
@ -0,0 +1,138 @@
|
||||
__author__ = "Piotr Gawlowicz"
|
||||
__copyright__ = "Copyright (c) 2018, Technische Universität Berlin"
|
||||
__version__ = "0.1.0"
|
||||
__email__ = "gawlowicz@tkn.tu-berlin.de"
|
||||
|
||||
|
||||
class Tcp(object):
|
||||
"""docstring for Tcp"""
|
||||
def __init__(self):
|
||||
super(Tcp, self).__init__()
|
||||
|
||||
def set_spaces(self, obs, act):
|
||||
self.obsSpace = obs
|
||||
self.actSpace = act
|
||||
|
||||
def get_action(self, obs, reward, done, info):
|
||||
pass
|
||||
|
||||
|
||||
class TcpEventBased(Tcp):
|
||||
"""docstring for TcpEventBased"""
|
||||
def __init__(self):
|
||||
super(TcpEventBased, self).__init__()
|
||||
|
||||
def get_action(self, obs, reward, done, info):
|
||||
# unique socket ID
|
||||
socketUuid = obs[0]
|
||||
# TCP env type: event-based = 0 / time-based = 1
|
||||
envType = obs[1]
|
||||
# sim time in us
|
||||
simTime_us = obs[2]
|
||||
# unique node ID
|
||||
nodeId = obs[3]
|
||||
# current ssThreshold
|
||||
ssThresh = obs[4]
|
||||
# current contention window size
|
||||
cWnd = obs[5]
|
||||
# segment size
|
||||
segmentSize = obs[6]
|
||||
# number of acked segments
|
||||
segmentsAcked = obs[7]
|
||||
# estimated bytes in flight
|
||||
bytesInFlight = obs[8]
|
||||
# last estimation of RTT
|
||||
lastRtt_us = obs[9]
|
||||
# min value of RTT
|
||||
minRtt_us = obs[10]
|
||||
# function from Congestion Algorithm (CA) interface:
|
||||
# GET_SS_THRESH = 0 (packet loss),
|
||||
# INCREASE_WINDOW (packet acked),
|
||||
# PKTS_ACKED (unused),
|
||||
# CONGESTION_STATE_SET (unused),
|
||||
# CWND_EVENT (unused),
|
||||
calledFunc = obs[11]
|
||||
# Congetsion Algorithm (CA) state:
|
||||
# CA_OPEN = 0,
|
||||
# CA_DISORDER,
|
||||
# CA_CWR,
|
||||
# CA_RECOVERY,
|
||||
# CA_LOSS,
|
||||
# CA_LAST_STATE
|
||||
caState = obs[12]
|
||||
# Congetsion Algorithm (CA) event:
|
||||
# CA_EVENT_TX_START = 0,
|
||||
# CA_EVENT_CWND_RESTART,
|
||||
# CA_EVENT_COMPLETE_CWR,
|
||||
# CA_EVENT_LOSS,
|
||||
# CA_EVENT_ECN_NO_CE,
|
||||
# CA_EVENT_ECN_IS_CE,
|
||||
# CA_EVENT_DELAYED_ACK,
|
||||
# CA_EVENT_NON_DELAYED_ACK,
|
||||
caEvent = obs[13]
|
||||
# ECN state:
|
||||
# ECN_DISABLED = 0,
|
||||
# ECN_IDLE,
|
||||
# ECN_CE_RCVD,
|
||||
# ECN_SENDING_ECE,
|
||||
# ECN_ECE_RCVD,
|
||||
# ECN_CWR_SENT
|
||||
ecnState = obs[14]
|
||||
|
||||
# compute new values
|
||||
new_cWnd = 10 * segmentSize
|
||||
new_ssThresh = 5 * segmentSize
|
||||
|
||||
# return actions
|
||||
actions = [new_ssThresh, new_cWnd]
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
class TcpTimeBased(Tcp):
|
||||
"""docstring for TcpTimeBased"""
|
||||
def __init__(self):
|
||||
super(TcpTimeBased, self).__init__()
|
||||
|
||||
def get_action(self, obs, reward, done, info):
|
||||
# unique socket ID
|
||||
socketUuid = obs[0]
|
||||
# TCP env type: event-based = 0 / time-based = 1
|
||||
envType = obs[1]
|
||||
# sim time in us
|
||||
simTime_us = obs[2]
|
||||
# unique node ID
|
||||
nodeId = obs[3]
|
||||
# current ssThreshold
|
||||
ssThresh = obs[4]
|
||||
# current congestion window size
|
||||
cWnd = obs[5]
|
||||
# segment size
|
||||
segmentSize = obs[6]
|
||||
# bytesInFlightSum
|
||||
bytesInFlightSum = obs[7]
|
||||
# bytesInFlightAvg
|
||||
bytesInFlightAvg = obs[8]
|
||||
# segmentsAckedSum
|
||||
segmentsAckedSum = obs[9]
|
||||
# segmentsAckedAvg
|
||||
segmentsAckedAvg = obs[10]
|
||||
# avgRtt
|
||||
avgRtt = obs[11]
|
||||
# minRtt
|
||||
minRtt = obs[12]
|
||||
# avgInterTx
|
||||
avgInterTx = obs[13]
|
||||
# avgInterRx
|
||||
avgInterRx = obs[14]
|
||||
# throughput
|
||||
throughput = obs[15]
|
||||
|
||||
# compute new values
|
||||
new_cWnd = 10 * segmentSize
|
||||
new_ssThresh = 5 * segmentSize
|
||||
|
||||
# return actions
|
||||
actions = [new_ssThresh, new_cWnd]
|
||||
|
||||
return actions
|
||||
Loading…
x
Reference in New Issue
Block a user