Upload files

This commit is contained in:
Kaushik Narayan R 2022-05-14 23:02:03 +05:30
parent 2d77152b5d
commit 44717b8f3e
7 changed files with 2162 additions and 0 deletions

258
TCP-RL-Agent.py Executable file
View File

@ -0,0 +1,258 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
import argparse
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import tensorflow as tf
from ns3gym import ns3env
from tcp_base import TcpTimeBased, TcpEventBased
# Open the per-run log file. If it cannot be created (e.g. the working
# directory is not writable) fall back to stdout so logging still works.
# Only OSError is handled: a bare ``except`` would also swallow
# KeyboardInterrupt and SystemExit.
try:
    w_file = open('run.log', 'w')
except OSError:
    w_file = sys.stdout
# Command-line interface: every option is an integer flag with a default.
parser = argparse.ArgumentParser(description='Start simulation script on/off')
for flag, default_value, help_text in (
        ('--start', 1, 'Start ns-3 simulation script 0/1, Default: 1'),
        ('--iterations', 1, 'Number of iterations, Default: 1'),
        ('--steps', 100, 'Number of steps, Default 100'),
        ('--debug', 0, 'Show debug output 0/1, Default 0'),
):
    parser.add_argument(flag, type=int, default=default_value, help=help_text)
args = parser.parse_args()

# Run configuration derived from the parsed options.
startSim = bool(args.start)
iterationNum = int(args.iterations)
maxSteps = int(args.steps)

# Simulation parameters handed to the ns-3 environment.
port = 5555
simTime = maxSteps / 10.0  # seconds
stepTime = simTime / 200.0  # seconds
seed = 12
simArgs = {"--duration": simTime}

# Wait for the user before the environment is created.
dashes = "-"*18
input("[{}Press Enter to start{}]".format(dashes, dashes))

# Create the environment and keep its observation/action spaces at hand.
env = ns3env.Ns3Env(
    port=port,
    stepTime=stepTime,
    startSim=startSim,
    simSeed=seed,
    simArgs=simArgs,
)
ob_space = env.observation_space
ac_space = env.action_space
# TODO: the next action is currently selected inside the training loop instead
# of through get_agent, because the decaying epsilon-greedy selection needs
# the live model. Move that logic into an `RLTCP` class deriving from the Tcp
# class (as in tcp_base.py), put the class in tcp_base.py and use it here.
def get_agent(state):
    """Return the TCP agent registered for the socket described by ``state``.

    ``state[0]`` is the socket UUID and ``state[1]`` the env type
    (0 = event-based, anything else = time-based). Agents are cached per
    socket UUID in ``get_agent.tcpAgents``; a new agent is created, given the
    spaces stored on ``get_agent`` and cached on first use.
    """
    socket_uuid, env_type = state[0], state[1]

    agent = get_agent.tcpAgents.get(socket_uuid, None)
    if agent is not None:
        return agent

    # No agent yet for this socket: pick the class matching the env type.
    agent_class = TcpEventBased if env_type == 0 else TcpTimeBased
    agent = agent_class()
    agent.set_spaces(get_agent.ob_space, get_agent.ac_space)
    get_agent.tcpAgents[socket_uuid] = agent
    return agent
# State kept on the get_agent function object: the per-socket agent cache and
# the spaces handed to every new agent. (Not exercised by the training loop
# until the get_agent TODO is resolved.)
get_agent.tcpAgents = {}
get_agent.ob_space, get_agent.ac_space = ob_space, ac_space
def modeler(input_size, output_size):
    """Build a small fully connected Keras network.

    The network has one ReLU hidden layer of
    ``(input_size + output_size) // 2`` units fed by ``input_size`` features,
    followed by a softmax layer of ``output_size`` units.

    The returned model is not compiled.
    """
    hidden_units = (input_size + output_size) // 2
    return tf.keras.Sequential([
        # hidden layer (also declares the input shape)
        tf.keras.layers.Dense(hidden_units, input_shape=(input_size,), activation='relu'),
        # output layer: one softmax unit per action
        tf.keras.layers.Dense(output_size, activation='softmax'),
    ])
# Number of model inputs: the observation without its four leading env
# attributes (socket ID, env type, sim time, node ID).
state_size = ob_space.shape[0] - 4

# Discrete actions: action index -> change applied to the congestion window.
action_mapping = {0: 0, 1: 600, 2: -150}
action_size = 3

# Build and compile the model.
model = modeler(state_size, action_size)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()

# Decaying epsilon-greedy exploration: epsilon starts at 1.0 and is multiplied
# by epsilon_decay after every step until it falls to min_epsilon.
# Tune these to balance exploration and exploitation.
epsilon = 1.0
min_epsilon = 0.1
epsilon_decay_param = iterationNum * 5
decay_horizon = epsilon_decay_param * maxSteps
epsilon_decay = (decay_horizon - 1.0) / decay_horizon

# Q-learning discount factor.
discount_factor = 0.95

# Running totals and per-step histories used for the final plots.
rewardsum = 0
rew_history, cWnd_history, pred_cWnd_history, rtt_history = [], [], [], []
done = False

# Frames of the console spinner shown while logging to file.
pretty_slash = ['\\', '|', '/', '-']
# Training loop: one environment episode per iteration, at most maxSteps
# agent/environment interactions per episode.
for iteration in range(iterationNum):
    # Reset the environment. The raw observation still carries the four
    # leading env attributes: socketID, env type, sim time, nodeID.
    raw_state = env.reset()
    # The feature layout depends on the env type (raw index 1: 0 = event-based,
    # otherwise time-based). After the four env attributes are dropped, the
    # event-based observation in tcp-rl-env.cc has 6 features with rtt at
    # index 5, the time-based one has 12 features with avgRtt at index 7.
    # A fixed index 7 is out of range for event-based observations.
    rtt_index = 5 if raw_state[1] == 0 else 7
    state = raw_state[4:]
    cWnd = state[1]
    init_cWnd = cWnd
    # Lower bound for the window requested from the env: one segment
    # (feature index 2 is segmentSize in both layouts).
    # NOTE(review): confirm one segment is the intended floor.
    min_cWnd = int(state[2])
    state = np.reshape(state, [1, state_size])
    try:
        for step in range(maxSteps):
            # Console spinner (stdout); detailed output goes to w_file.
            pretty_index = step % 4
            print("\r{}\r[{}] Logging to file {} {}".format(
                ' '*(25+len(w_file.name)),
                pretty_slash[pretty_index],
                w_file.name,
                '.'*(pretty_index+1)
            ), end='')
            print("[+] Step: {}".format(step+1), file=w_file)
            # Epsilon-greedy selection
            if step == 0 or np.random.rand(1) < epsilon:
                # explore new situation
                action_index = np.random.randint(0, action_size)
                print("\t[*] Random exploration. Selected action: {}".format(action_index), file=w_file)
            else:
                # exploit gained knowledge
                action_index = np.argmax(model.predict(state)[0])
                print("\t[*] Exploiting gained knowledge. Selected action: {}".format(action_index), file=w_file)
            # Calculate action.
            # The env's action space is unsigned, so the requested window is
            # floored at min_cWnd; plain ints avoid unsigned/negative mixing.
            new_cWnd = max(min_cWnd, int(cWnd) + action_mapping[action_index])
            new_ssThresh = int(cWnd/2)
            actions = [new_ssThresh, new_cWnd]
            # Take action step on environment and get feedback
            next_state, reward, done, _ = env.step(actions)
            rewardsum += reward
            next_state = next_state[4:]
            cWnd = next_state[1]
            rtt = next_state[rtt_index]
            print("\t[#] Next state: ", next_state, file=w_file)
            print("\t[!] Reward: ", reward, file=w_file)
            next_state = np.reshape(next_state, [1, state_size])
            # Train incrementally
            # DQN - function approximation using neural networks
            target = reward
            if not done:
                target = (reward + discount_factor * np.amax(model.predict(next_state)[0]))
            target_f = model.predict(state)
            target_f[0][action_index] = target
            model.fit(state, target_f, epochs=1, verbose=0)
            # Update state
            state = next_state
            if done:
                print("[X] Stopping: step: {}, reward sum: {}, epsilon: {:.2}"
                      .format(step+1, rewardsum, epsilon),
                      file=w_file)
                break
            if epsilon > min_epsilon:
                epsilon *= epsilon_decay
            # Record information (one sample per completed step)
            rew_history.append(rewardsum)
            rtt_history.append(rtt)
            cWnd_history.append(cWnd)
            pred_cWnd_history.append(new_cWnd)
        print("\n[O] Iteration over.", file=w_file)
        print("[-] Final epsilon value: ", epsilon, file=w_file)
        print("[-] Final reward sum: ", rewardsum, file=w_file)
        print()
    finally:
        # Always end the spinner line. There is deliberately no ``break`` here:
        # a ``break`` inside ``finally`` discards any in-flight exception, and
        # the loop ends on its own after the last iteration.
        print()
# Plot the recorded per-step histories in a 2x2 grid.
mpl.rcdefaults()
mpl.rcParams.update({'font.size': 12})
fig, ax = plt.subplots(2, 2, figsize=(4,2))
plt.tight_layout(pad=0.3)

# (axes, series, title, y-axis label) for each panel; the x axis is always
# the step index.
panels = (
    (ax[0, 0], cWnd_history, 'Congestion windows', 'Actual CWND'),
    (ax[0, 1], pred_cWnd_history, 'Predicted values', 'Predicted CWND'),
    (ax[1, 0], rtt_history, 'RTT over time', 'RTT (microseconds)'),
    (ax[1, 1], rew_history, 'Reward sum plot', 'Reward sum'),
)
for axis, series, title, y_label in panels:
    axis.plot(range(len(series)), series, marker="", linestyle="-")
    axis.set_title(title)
    axis.set_xlabel('Steps')
    axis.set_ylabel(y_label)

plt.show()

333
sim.cc Normal file
View File

@ -0,0 +1,333 @@
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
/*
* Copyright (c) 2018 Piotr Gawlowicz
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation;
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*
* Author: Piotr Gawlowicz <gawlowicz.p@gmail.com>
* Based on script: ./examples/tcp/tcp-variants-comparison.cc
*
* Topology:
*
* Left Leafs (Clients) Right Leafs (Sinks)
* | \ / |
* | \ bottleneck / |
* | R0--------------R1 |
* | / \ |
* | access / \ access |
* N ----------- --------N
*/
#include <iostream>
#include <fstream>
#include <string>
#include "ns3/core-module.h"
#include "ns3/network-module.h"
#include "ns3/internet-module.h"
#include "ns3/point-to-point-module.h"
#include "ns3/point-to-point-layout-module.h"
#include "ns3/applications-module.h"
#include "ns3/error-model.h"
#include "ns3/tcp-header.h"
#include "ns3/enum.h"
#include "ns3/event-id.h"
#include "ns3/flow-monitor-helper.h"
#include "ns3/ipv4-global-routing-helper.h"
#include "ns3/traffic-control-module.h"
#include "ns3/opengym-module.h"
#include "tcp-rl.h"
using namespace ns3;
NS_LOG_COMPONENT_DEFINE ("TcpVariantsComparison");
// Received-packet counter per sink, indexed by sink id.
static std::vector<uint32_t> rxPkts;

/*
Trace sink bound to a sink id: increments that sink's packet counter.
The packet and source address are ignored. sinkId must be a valid index
into rxPkts (no bounds check is performed).
*/
static void
CountRxPkts (uint32_t sinkId, Ptr<const Packet> packet, const Address & srcAddr)
{
  ++rxPkts[sinkId];
}
/*
Print the received-packet count of every sink, one line per sink.
*/
static void
PrintRxCount ()
{
  NS_LOG_UNCOND ("RxPkts:");
  for (uint32_t sinkId = 0; sinkId < rxPkts.size (); ++sinkId)
    {
      NS_LOG_UNCOND ("---SinkId: "<< sinkId << " RxPkts: " << rxPkts.at (sinkId));
    }
}
// Entry point of the simulation script.
// Parses the command line, configures the TCP stack, builds a dumbbell
// topology (nLeaf left nodes, a bottleneck link, nLeaf right nodes), installs
// bulk senders on the left nodes and packet sinks on the right nodes, runs the
// simulation and prints the number of packets received by every sink.
// For the TcpRl / TcpRlTimeBased variants the OpenGym interface is opened
// first and notified when the simulation ends.
int main (int argc, char *argv[])
{
// Defaults; most of them can be overridden on the command line below.
uint32_t openGymPort = 5555;
double tcpEnvTimeStep = 0.1;
uint32_t nLeaf = 1;
std::string transport_prot = "TcpRl";
double error_p = 0.0;
std::string bottleneck_bandwidth = "2Mbps";
std::string bottleneck_delay = "0.01ms";
std::string access_bandwidth = "10Mbps";
std::string access_delay = "20ms";
std::string prefix_file_name = "TcpVariantsComparison";
uint64_t data_mbytes = 0;
uint32_t mtu_bytes = 400;
double duration = 10.0;
uint32_t run = 0;
bool flow_monitor = false;
bool sack = true;
std::string queue_disc_type = "ns3::PfifoFastQueueDisc";
std::string recovery = "ns3::TcpClassicRecovery";
CommandLine cmd;
// required parameters for OpenGym interface
// NOTE: "simSeed" and "run" (further below) both store into `run`.
cmd.AddValue ("openGymPort", "Port number for OpenGym env. Default: 5555", openGymPort);
cmd.AddValue ("simSeed", "Seed for random generator. Default: 1", run);
cmd.AddValue ("envTimeStep", "Time step interval for time-based TCP env [s]. Default: 0.1s", tcpEnvTimeStep);
// other parameters
cmd.AddValue ("nLeaf", "Number of left and right side leaf nodes", nLeaf);
cmd.AddValue ("transport_prot", "Transport protocol to use: TcpNewReno, "
"TcpHybla, TcpHighSpeed, TcpHtcp, TcpVegas, TcpScalable, TcpVeno, "
"TcpBic, TcpYeah, TcpIllinois, TcpWestwood, TcpWestwoodPlus, TcpLedbat, "
"TcpLp, TcpRl, TcpRlTimeBased", transport_prot);
cmd.AddValue ("error_p", "Packet error rate", error_p);
cmd.AddValue ("bottleneck_bandwidth", "Bottleneck bandwidth", bottleneck_bandwidth);
cmd.AddValue ("bottleneck_delay", "Bottleneck delay", bottleneck_delay);
cmd.AddValue ("access_bandwidth", "Access link bandwidth", access_bandwidth);
cmd.AddValue ("access_delay", "Access link delay", access_delay);
cmd.AddValue ("prefix_name", "Prefix of output trace file", prefix_file_name);
cmd.AddValue ("data", "Number of Megabytes of data to transmit", data_mbytes);
cmd.AddValue ("mtu", "Size of IP packets to send in bytes", mtu_bytes);
cmd.AddValue ("duration", "Time to allow flows to run in seconds", duration);
cmd.AddValue ("run", "Run index (for setting repeatable seeds)", run);
cmd.AddValue ("flow_monitor", "Enable flow monitor", flow_monitor);
cmd.AddValue ("queue_disc_type", "Queue disc type for gateway (e.g. ns3::CoDelQueueDisc)", queue_disc_type);
cmd.AddValue ("sack", "Enable or disable SACK option", sack);
cmd.AddValue ("recovery", "Recovery algorithm type to use (e.g., ns3::TcpPrrRecovery", recovery);
cmd.Parse (argc, argv);
// Turn the short protocol name into a full ns-3 TypeId name.
transport_prot = std::string ("ns3::") + transport_prot;
// The seed is fixed; only the run number changes between invocations.
SeedManager::SetSeed (1);
SeedManager::SetRun (run);
NS_LOG_UNCOND("Ns3Env parameters:");
if (transport_prot.compare ("ns3::TcpRl") == 0 or transport_prot.compare ("ns3::TcpRlTimeBased") == 0)
{
NS_LOG_UNCOND("--openGymPort: " << openGymPort);
} else {
NS_LOG_UNCOND("--openGymPort: No OpenGym");
}
NS_LOG_UNCOND("--seed: " << run);
NS_LOG_UNCOND("--Tcp version: " << transport_prot);
// OpenGym Env --- has to be created before any other thing
// The interface is only created for the two RL variants; for every other
// protocol openGymInterface stays a null pointer.
Ptr<OpenGymInterface> openGymInterface;
if (transport_prot.compare ("ns3::TcpRl") == 0)
{
openGymInterface = OpenGymInterface::Get(openGymPort);
Config::SetDefault ("ns3::TcpRl::Reward", DoubleValue (2.0)); // Reward when increasing congestion window
Config::SetDefault ("ns3::TcpRl::Penalty", DoubleValue (-30.0)); // Penalty when decreasing congestion window
}
if (transport_prot.compare ("ns3::TcpRlTimeBased") == 0)
{
openGymInterface = OpenGymInterface::Get(openGymPort);
Config::SetDefault ("ns3::TcpRlTimeBased::StepTime", TimeValue (Seconds(tcpEnvTimeStep))); // Time step of env
Config::SetDefault ("ns3::TcpRlTimeBased::Duration", TimeValue (Seconds(duration))); // Duration of env sim
Config::SetDefault ("ns3::TcpRlTimeBased::Reward", DoubleValue (1.0)); // Reward
Config::SetDefault ("ns3::TcpRlTimeBased::Penalty", DoubleValue (-1.0)); // Penalty
}
// Calculate the ADU size: the MTU minus 20 bytes minus the serialized IP and
// TCP header sizes.
// NOTE(review): the purpose of the extra 20 bytes is not documented here, and
// the unsigned subtraction wraps around if mtu_bytes is smaller than the sum.
Header* temp_header = new Ipv4Header ();
uint32_t ip_header = temp_header->GetSerializedSize ();
NS_LOG_LOGIC ("IP Header size is: " << ip_header);
delete temp_header;
temp_header = new TcpHeader ();
uint32_t tcp_header = temp_header->GetSerializedSize ();
NS_LOG_LOGIC ("TCP Header size is: " << tcp_header);
delete temp_header;
uint32_t tcp_adu_size = mtu_bytes - 20 - (ip_header + tcp_header);
NS_LOG_LOGIC ("TCP ADU size is: " << tcp_adu_size);
// Set the simulation start and stop time
double start_time = 0.1; // start offset; the author observed that some variables need time to initialise (reason not documented)
double stop_time = start_time + duration;
// 2 MiB (1 << 21 bytes) of TCP receive and send buffer
Config::SetDefault ("ns3::TcpSocket::RcvBufSize", UintegerValue (1 << 21));
Config::SetDefault ("ns3::TcpSocket::SndBufSize", UintegerValue (1 << 21));
Config::SetDefault ("ns3::TcpSocketBase::Sack", BooleanValue (sack));
// number of packets received before an ACK is sent (set explicitly to 2)
Config::SetDefault ("ns3::TcpSocket::DelAckCount", UintegerValue (2));
Config::SetDefault ("ns3::TcpL4Protocol::RecoveryType",
TypeIdValue (TypeId::LookupByName (recovery)));
// Select TCP variant
if (transport_prot.compare ("ns3::TcpWestwoodPlus") == 0)
{
// TcpWestwoodPlus is not an actual TypeId name; we need TcpWestwood here
Config::SetDefault ("ns3::TcpL4Protocol::SocketType", TypeIdValue (TcpWestwood::GetTypeId ()));
// the default protocol type in ns3::TcpWestwood is WESTWOOD
Config::SetDefault ("ns3::TcpWestwood::ProtocolType", EnumValue (TcpWestwood::WESTWOODPLUS));
}
else
{
// Abort with a readable message when the requested TypeId does not exist.
TypeId tcpTid;
NS_ABORT_MSG_UNLESS (TypeId::LookupByNameFailSafe (transport_prot, &tcpTid), "TypeId " << transport_prot << " not found");
Config::SetDefault ("ns3::TcpL4Protocol::SocketType", TypeIdValue (TypeId::LookupByName (transport_prot)));
}
// Configure the error model
// Here we use RateErrorModel with packet error rate
// NOTE: the model is configured but not attached to any device as long as the
// "ReceiveErrorModel" line below stays commented out, so error_p has no effect.
Ptr<UniformRandomVariable> uv = CreateObject<UniformRandomVariable> ();
uv->SetStream (50);
RateErrorModel error_model;
error_model.SetRandomVariable (uv);
error_model.SetUnit (RateErrorModel::ERROR_UNIT_PACKET);
error_model.SetRate (error_p);
// Create the point-to-point link helpers
PointToPointHelper bottleNeckLink;
bottleNeckLink.SetDeviceAttribute ("DataRate", StringValue (bottleneck_bandwidth));
bottleNeckLink.SetChannelAttribute ("Delay", StringValue (bottleneck_delay));
//bottleNeckLink.SetDeviceAttribute ("ReceiveErrorModel", PointerValue (&error_model));
PointToPointHelper pointToPointLeaf;
pointToPointLeaf.SetDeviceAttribute ("DataRate", StringValue (access_bandwidth));
pointToPointLeaf.SetChannelAttribute ("Delay", StringValue (access_delay));
// Dumbbell: nLeaf left leaves and nLeaf right leaves joined by the bottleneck.
PointToPointDumbbellHelper d (nLeaf, pointToPointLeaf,
nLeaf, pointToPointLeaf,
bottleNeckLink);
// Install IP stack
InternetStackHelper stack;
stack.InstallAll ();
// Traffic Control
TrafficControlHelper tchPfifo;
tchPfifo.SetRootQueueDisc ("ns3::PfifoFastQueueDisc");
TrafficControlHelper tchCoDel;
tchCoDel.SetRootQueueDisc ("ns3::CoDelQueueDisc");
DataRate access_b (access_bandwidth);
DataRate bottle_b (bottleneck_bandwidth);
Time access_d (access_delay);
Time bottle_d (bottleneck_delay);
// Queue size in bytes: the slower link's rate (bytes/s) multiplied by twice
// the one-way delay (access + bottleneck + access).
uint32_t size = static_cast<uint32_t>((std::min (access_b, bottle_b).GetBitRate () / 8) *
((access_d + bottle_d + access_d) * 2).GetSeconds ());
Config::SetDefault ("ns3::PfifoFastQueueDisc::MaxSize",
QueueSizeValue (QueueSize (QueueSizeUnit::PACKETS, size / mtu_bytes)));
Config::SetDefault ("ns3::CoDelQueueDisc::MaxSize",
QueueSizeValue (QueueSize (QueueSizeUnit::BYTES, size)));
// Install the selected queue disc on device 1 of both dumbbell routers.
if (queue_disc_type.compare ("ns3::PfifoFastQueueDisc") == 0)
{
tchPfifo.Install (d.GetLeft()->GetDevice(1));
tchPfifo.Install (d.GetRight()->GetDevice(1));
}
else if (queue_disc_type.compare ("ns3::CoDelQueueDisc") == 0)
{
tchCoDel.Install (d.GetLeft()->GetDevice(1));
tchCoDel.Install (d.GetRight()->GetDevice(1));
}
else
{
NS_FATAL_ERROR ("Queue not recognized. Allowed values are ns3::CoDelQueueDisc or ns3::PfifoFastQueueDisc");
}
// Assign IP Addresses
d.AssignIpv4Addresses (Ipv4AddressHelper ("10.1.1.0", "255.255.255.0"),
Ipv4AddressHelper ("10.2.1.0", "255.255.255.0"),
Ipv4AddressHelper ("10.3.1.0", "255.255.255.0"));
NS_LOG_INFO ("Initialize Global Routing.");
Ipv4GlobalRoutingHelper::PopulateRoutingTables ();
// Install apps in left and right nodes
// Right nodes: one TCP packet sink each, all listening on the same port.
uint16_t port = 50000;
Address sinkLocalAddress (InetSocketAddress (Ipv4Address::GetAny (), port));
PacketSinkHelper sinkHelper ("ns3::TcpSocketFactory", sinkLocalAddress);
ApplicationContainer sinkApps;
for (uint32_t i = 0; i < d.RightCount (); ++i)
{
sinkHelper.SetAttribute ("Protocol", TypeIdValue (TcpSocketFactory::GetTypeId ()));
sinkApps.Add (sinkHelper.Install (d.GetRight (i)));
}
sinkApps.Start (Seconds (0.0));
sinkApps.Stop (Seconds (stop_time));
for (uint32_t i = 0; i < d.LeftCount (); ++i)
{
// Create a bulk-send app on left node i that sends to the sink on right node i
AddressValue remoteAddress (InetSocketAddress (d.GetRightIpv4Address (i), port));
Config::SetDefault ("ns3::TcpSocket::SegmentSize", UintegerValue (tcp_adu_size));
BulkSendHelper ftp ("ns3::TcpSocketFactory", Address ());
ftp.SetAttribute ("Remote", remoteAddress);
ftp.SetAttribute ("SendSize", UintegerValue (tcp_adu_size));
ftp.SetAttribute ("MaxBytes", UintegerValue (data_mbytes * 1000000));
ApplicationContainer clientApp = ftp.Install (d.GetLeft (i));
clientApp.Start (Seconds (start_time * i)); // Staggered start; client 0 starts at 0 s, the same time as the sinks
clientApp.Stop (Seconds (stop_time - 3)); // Stop 3 s before the sinks
}
// Flow monitor
FlowMonitorHelper flowHelper;
if (flow_monitor)
{
flowHelper.InstallAll ();
}
// Count RX packets: one counter per sink, incremented from the sink's "Rx" trace
for (uint32_t i = 0; i < d.RightCount (); ++i)
{
rxPkts.push_back(0);
Ptr<PacketSink> pktSink = DynamicCast<PacketSink>(sinkApps.Get(i));
pktSink->TraceConnectWithoutContext ("Rx", MakeBoundCallback (&CountRxPkts, i));
}
Simulator::Stop (Seconds (stop_time));
Simulator::Run ();
if (flow_monitor)
{
flowHelper.SerializeToXmlFile (prefix_file_name + ".flowmonitor", true, true);
}
// Tell the OpenGym side that the simulation is over (RL variants only, the
// only cases in which openGymInterface was created above).
if (transport_prot.compare ("ns3::TcpRl") == 0 or transport_prot.compare ("ns3::TcpRlTimeBased") == 0)
{
openGymInterface->NotifySimulationEnd();
}
PrintRxCount();
Simulator::Destroy ();
return 0;
}

714
tcp-rl-env.cc Normal file
View File

@ -0,0 +1,714 @@
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
/*
* Copyright (c) 2018 Technische Universität Berlin
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation;
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*
* Author: Piotr Gawlowicz <gawlowicz@tkn.tu-berlin.de>
*/
#include "tcp-rl-env.h"
#include "ns3/tcp-header.h"
#include "ns3/object.h"
#include "ns3/core-module.h"
#include "ns3/log.h"
#include "ns3/simulator.h"
#include "ns3/tcp-socket-base.h"
#include <vector>
#include <numeric>
namespace ns3 {
NS_LOG_COMPONENT_DEFINE ("ns3::TcpGymEnv");
NS_OBJECT_ENSURE_REGISTERED (TcpGymEnv);

/*
Construct the base TCP gym environment and attach it to the OpenGym
interface returned by OpenGymInterface::Get().
*/
TcpGymEnv::TcpGymEnv ()
{
  NS_LOG_FUNCTION (this);
  SetOpenGymInterface (OpenGymInterface::Get ());
}
/*
Destructor: only logs the call.
*/
TcpGymEnv::~TcpGymEnv ()
{
  NS_LOG_FUNCTION (this);
}
/*
Return the ns-3 TypeId of TcpGymEnv (parent OpenGymEnv, group "OpenGym").
No constructor is registered for this type.
*/
TypeId
TcpGymEnv::GetTypeId (void)
{
  static TypeId tid = TypeId ("ns3::TcpGymEnv")
    .SetParent<OpenGymEnv> ()
    .SetGroupName ("OpenGym");
  return tid;
}
/*
Dispose hook: only logs the call.
*/
void
TcpGymEnv::DoDispose ()
{
  NS_LOG_FUNCTION (this);
}
/*
Store the id of the node this environment belongs to; it is reported in the
observations and in log messages.
*/
void
TcpGymEnv::SetNodeId (uint32_t id)
{
  NS_LOG_FUNCTION (this);
  m_nodeId = id;
}
/*
Store the unique id of the socket this environment belongs to; it is
reported as the first value of every observation.
*/
void
TcpGymEnv::SetSocketUuid (uint32_t id)
{
  NS_LOG_FUNCTION (this);
  m_socketUuid = id;
}
/*
Map a TCP congestion state to its enumerator name.
Returns "UNKNOWN" for any value that is not listed.
*/
std::string
TcpGymEnv::GetTcpCongStateName (const TcpSocketState::TcpCongState_t state)
{
  switch (state)
    {
    case TcpSocketState::CA_OPEN:
      return "CA_OPEN";
    case TcpSocketState::CA_DISORDER:
      return "CA_DISORDER";
    case TcpSocketState::CA_CWR:
      return "CA_CWR";
    case TcpSocketState::CA_RECOVERY:
      return "CA_RECOVERY";
    case TcpSocketState::CA_LOSS:
      return "CA_LOSS";
    case TcpSocketState::CA_LAST_STATE:
      return "CA_LAST_STATE";
    default:
      return "UNKNOWN";
    }
}
/*
Map a TCP congestion-avoidance event to its enumerator name.
Returns "UNKNOWN" for any value that is not listed.
*/
std::string
TcpGymEnv::GetTcpCAEventName (const TcpSocketState::TcpCAEvent_t event)
{
  switch (event)
    {
    case TcpSocketState::CA_EVENT_TX_START:
      return "CA_EVENT_TX_START";
    case TcpSocketState::CA_EVENT_CWND_RESTART:
      return "CA_EVENT_CWND_RESTART";
    case TcpSocketState::CA_EVENT_COMPLETE_CWR:
      return "CA_EVENT_COMPLETE_CWR";
    case TcpSocketState::CA_EVENT_LOSS:
      return "CA_EVENT_LOSS";
    case TcpSocketState::CA_EVENT_ECN_NO_CE:
      return "CA_EVENT_ECN_NO_CE";
    case TcpSocketState::CA_EVENT_ECN_IS_CE:
      return "CA_EVENT_ECN_IS_CE";
    case TcpSocketState::CA_EVENT_DELAYED_ACK:
      return "CA_EVENT_DELAYED_ACK";
    case TcpSocketState::CA_EVENT_NON_DELAYED_ACK:
      return "CA_EVENT_NON_DELAYED_ACK";
    default:
      return "UNKNOWN";
    }
}
/*
Define action space: a box of two uint32 values in [0, 65535]
  [0] new_ssThresh
  [1] new_cWnd
*/
Ptr<OpenGymSpace>
TcpGymEnv::GetActionSpace ()
{
  const uint32_t actionCount = 2;
  const float lowerBound = 0.0;
  const float upperBound = 65535;
  std::vector<uint32_t> boxShape = {actionCount,};
  std::string valueType = TypeNameGet<uint32_t> ();

  Ptr<OpenGymBoxSpace> actionBox =
    CreateObject<OpenGymBoxSpace> (lowerBound, upperBound, boxShape, valueType);
  NS_LOG_INFO ("MyGetActionSpace: " << actionBox);
  return actionBox;
}
/*
Define game over condition.
The environment never reports game over by itself: the only trigger is a
test switch that is hard-wired to false (it would end the game on the
10th call). The call counter is shared by all instances.
*/
bool
TcpGymEnv::GetGameOver ()
{
  const bool testEnabled = false;
  static float callCount = 0.0;
  callCount += 1;

  m_isGameOver = (testEnabled && callCount == 10);
  NS_LOG_INFO ("MyGetGameOver: " << m_isGameOver);
  return m_isGameOver;
}
/*
Define reward function: returns the reward last stored in m_envReward by
the concrete environment.
*/
float
TcpGymEnv::GetReward ()
{
  NS_LOG_INFO ("MyGetReward: " << m_envReward);
  return m_envReward;
}
/*
Define extra info (optional): returns the string last stored in m_info.
*/
std::string
TcpGymEnv::GetExtraInfo ()
{
  NS_LOG_INFO ("MyGetExtraInfo: " << m_info);
  return m_info;
}
/*
Execute received actions.
The action is expected to be a uint32 box container holding
  [0] new_ssThresh
  [1] new_cWnd
Both values are stored for the congestion-control callbacks to apply.
Returns true when the action was stored. Returns false, leaving the
previously stored values untouched, when the container has a different
type (the cast yields a null pointer that must not be dereferenced).
*/
bool
TcpGymEnv::ExecuteActions(Ptr<OpenGymDataContainer> action)
{
  Ptr<OpenGymBoxContainer<uint32_t> > box = DynamicCast<OpenGymBoxContainer<uint32_t> >(action);
  if (!box)
    {
      NS_LOG_ERROR ("MyExecuteActions: action is not a uint32 box container: " << action);
      return false;
    }
  m_new_ssThresh = box->GetValue(0);
  m_new_cWnd = box->GetValue(1);
  NS_LOG_INFO ("MyExecuteActions: " << action);
  return true;
}
NS_OBJECT_ENSURE_REGISTERED (TcpEventGymEnv);

/*
Construct the event-based TCP gym environment on top of TcpGymEnv.
*/
TcpEventGymEnv::TcpEventGymEnv ()
  : TcpGymEnv ()
{
  NS_LOG_FUNCTION (this);
}
/*
Destructor: only logs the call.
*/
TcpEventGymEnv::~TcpEventGymEnv ()
{
  NS_LOG_FUNCTION (this);
}
/*
Return the ns-3 TypeId of TcpEventGymEnv (parent TcpGymEnv, group
"OpenGym", default constructor registered).
*/
TypeId
TcpEventGymEnv::GetTypeId (void)
{
  static TypeId tid = TypeId ("ns3::TcpEventGymEnv")
    .SetParent<TcpGymEnv> ()
    .SetGroupName ("OpenGym")
    .AddConstructor<TcpEventGymEnv> ();
  return tid;
}
/*
Dispose hook: only logs the call.
*/
void
TcpEventGymEnv::DoDispose ()
{
  NS_LOG_FUNCTION (this);
}
/*
Set the reward value; IncreaseWindow reports it as the environment reward.
*/
void
TcpEventGymEnv::SetReward (float value)
{
  NS_LOG_FUNCTION (this);
  m_reward = value;
}
/*
Set the penalty value; GetSsThresh reports it as the environment reward.
*/
void
TcpEventGymEnv::SetPenalty (float value)
{
  NS_LOG_FUNCTION (this);
  m_penalty = value;
}
/*
Define observation space: a box of 10 uint64 values in [0, 1e9], in the
order filled in by GetObservation:
  [0] socket unique ID
  [1] tcp env type: event-based = 0 / time-based = 1
  [2] sim time in us
  [3] node ID
  [4] ssThresh
  [5] cWnd
  [6] segmentSize
  [7] segmentsAcked
  [8] bytesInFlight
  [9] rtt in us
Min rtt, called func, congestion algorithm (CA) state, CA event and ECN
state are currently not part of the observation.
*/
Ptr<OpenGymSpace>
TcpEventGymEnv::GetObservationSpace ()
{
  const uint32_t fieldCount = 10;
  const float lowerBound = 0.0;
  const float upperBound = 1000000000.0;
  std::vector<uint32_t> boxShape = {fieldCount,};
  std::string valueType = TypeNameGet<uint64_t> ();

  Ptr<OpenGymBoxSpace> observationBox =
    CreateObject<OpenGymBoxSpace> (lowerBound, upperBound, boxShape, valueType);
  NS_LOG_INFO ("MyGetObservationSpace: " << observationBox);
  return observationBox;
}
/*
Collect observations: build the 10-value box described by
GetObservationSpace from the state captured by the last callbacks.
Reads m_tcb, so a callback must have stored the socket state first.
*/
Ptr<OpenGymDataContainer>
TcpEventGymEnv::GetObservation ()
{
  const uint32_t fieldCount = 10;
  std::vector<uint32_t> boxShape = {fieldCount,};
  Ptr<OpenGymBoxContainer<uint64_t> > observation =
    CreateObject<OpenGymBoxContainer<uint64_t> > (boxShape);

  // environment attributes
  observation->AddValue (m_socketUuid);
  observation->AddValue (0);                                   // env type: event-based
  observation->AddValue (Simulator::Now ().GetMicroSeconds ());
  observation->AddValue (m_nodeId);

  // TCP state
  observation->AddValue (m_tcb->m_ssThresh);
  observation->AddValue (m_tcb->m_cWnd);
  observation->AddValue (m_tcb->m_segmentSize);
  observation->AddValue (m_segmentsAcked);
  observation->AddValue (m_bytesInFlight);
  observation->AddValue (m_rtt.GetMicroSeconds ());

  // Not exported: m_tcb->m_minRtt, m_calledFunc, m_tcb->m_congState,
  // m_event, m_tcb->m_ecnState.

  // Print data
  NS_LOG_INFO ("MyGetObservation: " << observation);
  return observation;
}
/*
Packet-transmitted trace: ignored by the event-based environment (only
logs the call).
*/
void
TcpEventGymEnv::TxPktTrace (Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>)
{
  NS_LOG_FUNCTION (this);
}
/*
Packet-received trace: ignored by the event-based environment (only logs
the call).
*/
void
TcpEventGymEnv::RxPktTrace (Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>)
{
  NS_LOG_FUNCTION (this);
}
/*
Slow-start-threshold callback. Records the socket state and the bytes in
flight, sets the environment reward to the penalty, calls Notify() and
returns the ssThresh currently stored in m_new_ssThresh.
NOTE(review): Notify() is inherited; presumably it exchanges observation
and action with the agent so that m_new_ssThresh is fresh - confirm.
*/
uint32_t
TcpEventGymEnv::GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight)
{
  NS_LOG_FUNCTION (this);
  NS_LOG_INFO (Simulator::Now() << " Node: " << m_nodeId << " GetSsThresh, BytesInFlight: " << bytesInFlight);

  // state reported by the next observation
  m_tcb = tcb;
  m_bytesInFlight = bytesInFlight;
  m_calledFunc = CalledFunc_t::GET_SS_THRESH;
  m_info = "GetSsThresh";

  // pkt was lost, so penalty
  m_envReward = m_penalty;

  Notify ();
  return m_new_ssThresh;
}
/*
Window-increase callback. Records the socket state and the number of acked
segments, sets the environment reward to the reward value, calls Notify()
and then overwrites the socket's congestion window with m_new_cWnd.
*/
void
TcpEventGymEnv::IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked)
{
  NS_LOG_FUNCTION (this);
  NS_LOG_INFO (Simulator::Now() << " Node: " << m_nodeId << " IncreaseWindow, SegmentsAcked: " << segmentsAcked);

  // state reported by the next observation
  m_tcb = tcb;
  m_segmentsAcked = segmentsAcked;
  m_calledFunc = CalledFunc_t::INCREASE_WINDOW;
  m_info = "IncreaseWindow";

  // pkt was acked, so reward
  m_envReward = m_reward;

  Notify ();
  tcb->m_cWnd = m_new_cWnd;
}
/*
Packets-acked callback. Records the socket state, the number of acked
segments and the measured RTT for the next observation; does not notify.
*/
void
TcpEventGymEnv::PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt)
{
  NS_LOG_FUNCTION (this);
  NS_LOG_INFO (Simulator::Now() << " Node: " << m_nodeId << " PktsAcked, SegmentsAcked: " << segmentsAcked << " Rtt: " << rtt);

  m_tcb = tcb;
  m_segmentsAcked = segmentsAcked;
  m_rtt = rtt;
  m_calledFunc = CalledFunc_t::PKTS_ACKED;
  m_info = "PktsAcked";
}
/*
Congestion-state callback. Logs the new state by value and name and records
it together with the socket state; does not notify.
*/
void
TcpEventGymEnv::CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState)
{
  NS_LOG_FUNCTION (this);
  std::string newStateName = GetTcpCongStateName (newState);
  NS_LOG_INFO (Simulator::Now() << " Node: " << m_nodeId << " CongestionStateSet: " << newState << " " << newStateName);

  m_tcb = tcb;
  m_newState = newState;
  m_calledFunc = CalledFunc_t::CONGESTION_STATE_SET;
  m_info = "CongestionStateSet";
}
/*
Congestion-window-event callback. Logs the event by value and name and
records it together with the socket state; does not notify.
*/
void
TcpEventGymEnv::CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event)
{
  NS_LOG_FUNCTION (this);
  std::string caEventName = GetTcpCAEventName (event);
  NS_LOG_INFO (Simulator::Now() << " Node: " << m_nodeId << " CwndEvent: " << event << " " << caEventName);

  m_tcb = tcb;
  m_event = event;
  m_calledFunc = CalledFunc_t::CWND_EVENT;
  m_info = "CwndEvent";
}
NS_OBJECT_ENSURE_REGISTERED (TcpTimeStepGymEnv);

/*
Construct the time-step-based TCP gym environment on top of TcpGymEnv.
*/
TcpTimeStepGymEnv::TcpTimeStepGymEnv ()
  : TcpGymEnv ()
{
  NS_LOG_FUNCTION (this);
}
/*
Periodic state read: re-schedules itself m_timeStep from now and then
calls Notify(). Once started it keeps re-arming itself on every call.
*/
void
TcpTimeStepGymEnv::ScheduleNextStateRead ()
{
  NS_LOG_FUNCTION (this);
  // arm the next read before notifying for the current one
  Simulator::Schedule (m_timeStep, &TcpTimeStepGymEnv::ScheduleNextStateRead, this);
  Notify ();
}
/*
Destructor: only logs the call.
*/
TcpTimeStepGymEnv::~TcpTimeStepGymEnv ()
{
  NS_LOG_FUNCTION (this);
}
/*
Return the ns-3 TypeId of TcpTimeStepGymEnv (parent TcpGymEnv, group
"OpenGym", default constructor registered).
*/
TypeId
TcpTimeStepGymEnv::GetTypeId (void)
{
  static TypeId tid = TypeId ("ns3::TcpTimeStepGymEnv")
    .SetParent<TcpGymEnv> ()
    .SetGroupName ("OpenGym")
    .AddConstructor<TcpTimeStepGymEnv> ();
  return tid;
}
/*
Dispose hook: only logs the call.
*/
void
TcpTimeStepGymEnv::DoDispose ()
{
  NS_LOG_FUNCTION (this);
}
/*
Store the duration of the environment simulation.
*/
void
TcpTimeStepGymEnv::SetDuration (Time value)
{
  NS_LOG_FUNCTION (this);
  m_duration = value;
}
/*
Store the interval between two state reads; it is also the divisor of the
throughput computed in GetObservation.
*/
void
TcpTimeStepGymEnv::SetTimeStep (Time value)
{
  NS_LOG_FUNCTION (this);
  m_timeStep = value;
}
/*
Set the reward value used by the reward update in GetObservation.
*/
void
TcpTimeStepGymEnv::SetReward (float value)
{
  NS_LOG_FUNCTION (this);
  m_reward = value;
}
/*
Set the penalty value used by the reward update in GetObservation.
*/
void
TcpTimeStepGymEnv::SetPenalty (float value)
{
  NS_LOG_FUNCTION (this);
  m_penalty = value;
}
/*
Define observation space: a box of 16 uint64 values in [0, 1e9], in the
order filled in by GetObservation:
  [0]  socket unique ID
  [1]  tcp env type: event-based = 0 / time-based = 1
  [2]  sim time in us
  [3]  node ID
  [4]  ssThresh
  [5]  cWnd
  [6]  segmentSize
  [7]  bytesInFlightSum
  [8]  bytesInFlightAvg
  [9]  segmentsAckedSum
  [10] segmentsAckedAvg
  [11] avgRtt
  [12] minRtt
  [13] avgInterTx
  [14] avgInterRx
  [15] throughput
*/
Ptr<OpenGymSpace>
TcpTimeStepGymEnv::GetObservationSpace ()
{
  const uint32_t fieldCount = 16;
  const float lowerBound = 0.0;
  const float upperBound = 1000000000.0;
  std::vector<uint32_t> boxShape = {fieldCount,};
  std::string valueType = TypeNameGet<uint64_t> ();

  Ptr<OpenGymBoxSpace> observationBox =
    CreateObject<OpenGymBoxSpace> (lowerBound, upperBound, boxShape, valueType);
  NS_LOG_INFO ("MyGetObservationSpace: " << observationBox);
  return observationBox;
}
/*
Collect observations: build the 16-value box described by
GetObservationSpace from the samples gathered since the previous call.

Side effects:
  - updates m_envReward from the average RTT of this step (see below);
  - resets all per-step accumulators (bytes in flight, acked segments, RTT
    and inter-packet time sums) so the next step starts from scratch.

Reads m_tcb, so a callback must have stored the socket state first.
*/
Ptr<OpenGymDataContainer>
TcpTimeStepGymEnv::GetObservation()
{
  uint32_t parameterNum = 16;
  std::vector<uint32_t> shape = {parameterNum,};
  Ptr<OpenGymBoxContainer<uint64_t> > box = CreateObject<OpenGymBoxContainer<uint64_t> >(shape);

  // environment attributes
  box->AddValue(m_socketUuid);
  box->AddValue(1);   // env type: time-based
  box->AddValue(Simulator::Now().GetMicroSeconds ());
  box->AddValue(m_nodeId);

  // current TCP state
  box->AddValue(m_tcb->m_ssThresh);
  box->AddValue(m_tcb->m_cWnd);
  box->AddValue(m_tcb->m_segmentSize);

  // bytesInFlightSum
  // The initial value fixes the accumulator type of std::accumulate; a plain
  // 0 would sum in int and could overflow, so a 64-bit zero is passed.
  uint64_t bytesInFlightSum = std::accumulate(m_bytesInFlight.begin(), m_bytesInFlight.end(), static_cast<uint64_t> (0));
  box->AddValue(bytesInFlightSum);

  // bytesInFlightAvg
  uint64_t bytesInFlightAvg = 0;
  if (m_bytesInFlight.size()) {
    bytesInFlightAvg = bytesInFlightSum / m_bytesInFlight.size();
  }
  box->AddValue(bytesInFlightAvg);

  // segmentsAckedSum (64-bit accumulator, see above)
  uint64_t segmentsAckedSum = std::accumulate(m_segmentsAcked.begin(), m_segmentsAcked.end(), static_cast<uint64_t> (0));
  box->AddValue(segmentsAckedSum);

  // segmentsAckedAvg
  uint64_t segmentsAckedAvg = 0;
  if (m_segmentsAcked.size()) {
    segmentsAckedAvg = segmentsAckedSum / m_segmentsAcked.size();
  }
  box->AddValue(segmentsAckedAvg);

  // avgRtt
  Time avgRtt = Seconds(0.0);
  if(m_rttSampleNum) {
    avgRtt = m_rttSum / m_rttSampleNum;
  }
  box->AddValue(avgRtt.GetMicroSeconds ());

  /*---------------------------------------------------------------------------------------------------*/
  // Update reward based on overall average of avgRtt over all steps so far,
  // only when agent increases cWnd.
  // TODO: this is not the right way of doing this.
  //       place this somewhere else. see TcpEventGymEnv, how they've done it.
  if (m_new_cWnd > m_old_cWnd && m_totalAvgRttSum > 0 && avgRtt > 0) {
    // when agent increases cWnd
    if ((m_totalAvgRttSum / m_totalAvgRttNum) >= avgRtt) {
      // reward: this step's avgRtt is not above the running average
      m_envReward = m_reward;
    } else {
      // penalty: this step's avgRtt is above the running average
      m_envReward = m_penalty;
    }
  } else {
    // agent has not increased cWnd (or there is no RTT data yet)
    m_envReward = 0;
  }
  // Update m_totalAvgRttSum and m_totalAvgRttNum
  m_totalAvgRttSum += avgRtt;
  m_totalAvgRttNum++;
  m_old_cWnd = m_new_cWnd;
  /*---------------------------------------------------------------------------------------------------*/

  // m_minRtt
  box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ());

  // avgInterTx
  Time avgInterTx = Seconds(0.0);
  if (m_interTxTimeNum) {
    avgInterTx = m_interTxTimeSum / m_interTxTimeNum;
  }
  box->AddValue(avgInterTx.GetMicroSeconds ());

  // avgInterRx
  Time avgInterRx = Seconds(0.0);
  if (m_interRxTimeNum) {
    avgInterRx = m_interRxTimeSum / m_interRxTimeNum;
  }
  box->AddValue(avgInterRx.GetMicroSeconds ());

  // throughput bytes/s
  float throughput = (segmentsAckedSum * m_tcb->m_segmentSize) / m_timeStep.GetSeconds();
  box->AddValue(throughput);

  // Print data
  NS_LOG_INFO ("MyGetObservation: " << box);

  // Reset the per-step accumulators.
  m_bytesInFlight.clear();
  m_segmentsAcked.clear();
  m_rttSampleNum = 0;
  m_rttSum = MicroSeconds (0.0);
  m_interTxTimeNum = 0;
  m_interTxTimeSum = MicroSeconds (0.0);
  m_interRxTimeNum = 0;
  m_interRxTimeSum = MicroSeconds (0.0);
  return box;
}
// Tx trace sink: accumulates the time gap between consecutive transmissions.
// A stored timestamp of zero is treated as "no previous packet seen yet", so
// the first traced packet only records its timestamp.
void
TcpTimeStepGymEnv::TxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>)
{
  NS_LOG_FUNCTION (this);
  const bool hasPreviousTx = m_lastPktTxTime > MicroSeconds(0.0);
  if (hasPreviousTx)
    {
      // Add this gap to the running sum / sample count.
      m_interTxTimeSum += Simulator::Now() - m_lastPktTxTime;
      ++m_interTxTimeNum;
    }
  // Remember this transmission as the reference point for the next gap.
  m_lastPktTxTime = Simulator::Now();
}
// Rx trace sink: accumulates the time gap between consecutive receptions.
// A stored timestamp of zero is treated as "no previous packet seen yet", so
// the first traced packet only records its timestamp.
void
TcpTimeStepGymEnv::RxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>)
{
  NS_LOG_FUNCTION (this);
  const bool hasPreviousRx = m_lastPktRxTime > MicroSeconds(0.0);
  if (hasPreviousRx)
    {
      // Add this gap to the running sum / sample count.
      m_interRxTimeSum += Simulator::Now() - m_lastPktRxTime;
      ++m_interRxTimeNum;
    }
  // Remember this reception as the reference point for the next gap.
  m_lastPktRxTime = Simulator::Now();
}
// Congestion-control hook: records the reported bytes-in-flight sample and
// returns the slow-start threshold currently stored as the agent's action.
// On the very first hook call it also notifies the gym side and schedules the
// next state read.
uint32_t
TcpTimeStepGymEnv::GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight)
{
  NS_LOG_FUNCTION (this);
  NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " GetSsThresh, BytesInFlight: " << bytesInFlight);

  // Keep the latest socket state and collect the sample for this time step.
  m_tcb = tcb;
  m_bytesInFlight.push_back(bytesInFlight);

  const bool firstCallback = !m_started;
  if (firstCallback)
    {
      m_started = true;
      // NOTE(review): Notify() is not defined in this file; presumably the
      // OpenGymEnv state/action exchange - confirm.
      Notify();
      ScheduleNextStateRead();
    }

  // action
  return m_new_ssThresh;
}
// Congestion-control hook: records the acked-segment count and the socket's
// current bytes in flight, then overwrites the socket's congestion window with
// the value stored as the agent's action. On the very first hook call it also
// notifies the gym side and schedules the next state read.
void
TcpTimeStepGymEnv::IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked)
{
  NS_LOG_FUNCTION (this);
  NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " IncreaseWindow, SegmentsAcked: " << segmentsAcked);

  // Keep the latest socket state and collect the samples for this time step.
  m_tcb = tcb;
  m_segmentsAcked.push_back(segmentsAcked);
  m_bytesInFlight.push_back(tcb->m_bytesInFlight);

  const bool firstCallback = !m_started;
  if (firstCallback)
    {
      m_started = true;
      // NOTE(review): Notify() is not defined in this file; presumably the
      // OpenGymEnv state/action exchange - confirm.
      Notify();
      ScheduleNextStateRead();
    }

  // action
  tcb->m_cWnd = m_new_cWnd;
}
// Congestion-control hook: adds one RTT sample to the running sum/count that
// the observation code averages, and remembers the latest socket state.
void
TcpTimeStepGymEnv::PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt)
{
  NS_LOG_FUNCTION (this);
  NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " PktsAcked, SegmentsAcked: " << segmentsAcked << " Rtt: " << rtt);

  // One more RTT sample for the current time step.
  ++m_rttSampleNum;
  m_rttSum += rtt;

  m_tcb = tcb;
}
// Congestion-control hook: logs the new congestion state and remembers the
// latest socket state. The state itself is not stored.
void
TcpTimeStepGymEnv::CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState)
{
  NS_LOG_FUNCTION (this);
  m_tcb = tcb;

  // Human-readable name, used only by the log statement below.
  const std::string stateName = GetTcpCongStateName(newState);
  NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CongestionStateSet: " << newState << " " << stateName);
}
// Congestion-control hook: logs the congestion-window event. No environment
// state is updated here; the tcb argument is not used.
void
TcpTimeStepGymEnv::CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event)
{
  NS_LOG_FUNCTION (this);

  // Human-readable name, used only by the log statement below.
  const std::string eventName = GetTcpCAEventName(event);
  NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CwndEvent: " << event << " " << eventName);
}
} // namespace ns3

210
tcp-rl-env.h Normal file
View File

@ -0,0 +1,210 @@
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
/*
* Copyright (c) 2018 Technische Universität Berlin
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation;
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*
* Author: Piotr Gawlowicz <gawlowicz@tkn.tu-berlin.de>
*/
#ifndef TCP_RL_ENV_H
#define TCP_RL_ENV_H
#include "ns3/opengym-module.h"
#include "ns3/tcp-socket-base.h"
#include <vector>
namespace ns3 {
class Packet;
class TcpHeader;
class TcpSocketBase;
class Time;
// Abstract OpenGym environment that the TCP gym environments in this header
// derive from. It declares the OpenGym hooks that do not depend on the
// observation layout, and leaves the observation hooks, the packet trace
// sinks and the TCP congestion-control hooks pure virtual.
class TcpGymEnv : public OpenGymEnv
{
public:
TcpGymEnv ();
virtual ~TcpGymEnv ();
static TypeId GetTypeId (void);
virtual void DoDispose ();
// Identification setters for m_nodeId / m_socketUuid.
void SetNodeId(uint32_t id);
void SetSocketUuid(uint32_t id);
// Map enum values to names; the callers visible in this file use them for logging.
std::string GetTcpCongStateName(const TcpSocketState::TcpCongState_t state);
std::string GetTcpCAEventName(const TcpSocketState::TcpCAEvent_t event);
// OpenGym interface
virtual Ptr<OpenGymSpace> GetActionSpace();
virtual bool GetGameOver();
virtual float GetReward();
virtual std::string GetExtraInfo();
virtual bool ExecuteActions(Ptr<OpenGymDataContainer> action);
// Observation hooks: must be provided by the concrete environment.
virtual Ptr<OpenGymSpace> GetObservationSpace() = 0;
virtual Ptr<OpenGymDataContainer> GetObservation() = 0;
// trace packets, e.g. for calculating inter tx/rx time
virtual void TxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>) = 0;
virtual void RxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>) = 0;
// TCP congestion control interface
virtual uint32_t GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight) = 0;
virtual void IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked) = 0;
// optional functions used to collect obs
virtual void PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt) = 0;
virtual void CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState) = 0;
virtual void CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event) = 0;
// Identifies a congestion-control hook; enumerators are numbered from 0 in
// declaration order.
typedef enum
{
GET_SS_THRESH = 0,
INCREASE_WINDOW,
PKTS_ACKED,
CONGESTION_STATE_SET,
CWND_EVENT,
} CalledFunc_t;
protected:
// NOTE(review): none of the members below has an in-class initializer;
// presumably the constructor (not shown here) sets them - confirm.
uint32_t m_nodeId;
uint32_t m_socketUuid;
// state
// obs has to be implemented in child class
// game over
bool m_isGameOver;
// reward
float m_envReward;
// extra info
std::string m_info;
// actions
// Values handed back to TCP by the derived classes' GetSsThresh / IncreaseWindow.
uint32_t m_new_ssThresh;
uint32_t m_new_cWnd;
};
// Concrete TcpGymEnv that keeps one value per observed quantity (most recent
// hook, bytes in flight, segments acked, RTT, congestion state, CA event).
// NOTE(review): presumably the event-driven environment, reporting to the
// agent on each congestion-control hook call - confirm against its .cc file.
class TcpEventGymEnv : public TcpGymEnv
{
public:
TcpEventGymEnv ();
virtual ~TcpEventGymEnv ();
static TypeId GetTypeId (void);
virtual void DoDispose ();
// Configure the values stored in m_reward / m_penalty.
void SetReward(float value);
void SetPenalty(float value);
// OpenGym interface
virtual Ptr<OpenGymSpace> GetObservationSpace();
// Overrides the pure virtual in TcpGymEnv (virtual even without the keyword).
Ptr<OpenGymDataContainer> GetObservation();
// trace packets, e.g. for calculating inter tx/rx time
virtual void TxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>);
virtual void RxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>);
// TCP congestion control interface
virtual uint32_t GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight);
virtual void IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked);
// optional functions used to collect obs
virtual void PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt);
virtual void CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState);
virtual void CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event);
private:
// state
// Which congestion-control hook was called (see CalledFunc_t).
CalledFunc_t m_calledFunc;
Ptr<const TcpSocketState> m_tcb;
uint32_t m_bytesInFlight;
uint32_t m_segmentsAcked;
Time m_rtt;
TcpSocketState::TcpCongState_t m_newState;
TcpSocketState::TcpCAEvent_t m_event;
// reward
// Configured through SetReward / SetPenalty; no in-class default.
float m_reward;
float m_penalty;
};
// Concrete TcpGymEnv that accumulates samples (bytes in flight, segments
// acked, RTT, inter-packet gaps) between observations. The observation code
// reports sums/averages of these and then clears the accumulators.
class TcpTimeStepGymEnv : public TcpGymEnv
{
public:
TcpTimeStepGymEnv ();
virtual ~TcpTimeStepGymEnv ();
static TypeId GetTypeId (void);
virtual void DoDispose ();
// Configuration setters for m_duration / m_timeStep / m_reward / m_penalty.
void SetDuration(Time value);
void SetTimeStep(Time value);
void SetReward(float value);
void SetPenalty(float value);
// OpenGym interface
virtual Ptr<OpenGymSpace> GetObservationSpace();
// Overrides the pure virtual in TcpGymEnv (virtual even without the keyword).
Ptr<OpenGymDataContainer> GetObservation();
// trace packets, e.g. for calculating inter tx/rx time
virtual void TxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>);
virtual void RxPktTrace(Ptr<const Packet>, const TcpHeader&, Ptr<const TcpSocketBase>);
// TCP congestion control interface
virtual uint32_t GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight);
virtual void IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked);
// optional functions used to collect obs
virtual void PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt);
virtual void CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState);
virtual void CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event);
private:
// Called once from the first GetSsThresh / IncreaseWindow hook.
void ScheduleNextStateRead();
// Set by the first GetSsThresh / IncreaseWindow call so that the initial
// Notify() and ScheduleNextStateRead() run only once.
bool m_started {false};
Time m_duration;
Time m_timeStep;
// state
Ptr<const TcpSocketState> m_tcb;
// Per-step samples pushed by GetSsThresh / IncreaseWindow.
std::vector<uint32_t> m_bytesInFlight;
std::vector<uint32_t> m_segmentsAcked;
// Running RTT sum and sample count fed by PktsAcked.
uint64_t m_rttSampleNum {0};
Time m_rttSum {MicroSeconds (0.0)};
// Timestamps of the last traced packets; zero means "none seen yet".
Time m_lastPktTxTime {MicroSeconds(0.0)};
Time m_lastPktRxTime {MicroSeconds(0.0)};
// Running inter-packet gap sums and counts fed by TxPktTrace / RxPktTrace.
uint64_t m_interTxTimeNum {0};
Time m_interTxTimeSum {MicroSeconds (0.0)};
uint64_t m_interRxTimeNum {0};
Time m_interRxTimeSum {MicroSeconds (0.0)};
// NOTE(review): m_prevAvgRtt is not referenced in the visible part of the
// implementation - confirm whether it is still needed.
Time m_prevAvgRtt {MicroSeconds (0.0)};
// Sum and count of the per-step average RTTs over all steps so far; the
// observation code compares their ratio with the current step's average RTT.
Time m_totalAvgRttSum {MicroSeconds (0.0)};
uint64_t m_totalAvgRttNum {0};
// m_new_cWnd as of the previous observation, used to detect a window increase.
uint32_t m_old_cWnd {0};
// reward
// Configured through SetReward / SetPenalty; no in-class default.
float m_reward;
float m_penalty;
};
} // namespace ns3
#endif /* TCP_RL_ENV_H */

382
tcp-rl.cc Normal file
View File

@ -0,0 +1,382 @@
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
/*
* Copyright (c) 2018 Technische Universität Berlin
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation;
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*
* Author: Piotr Gawlowicz <gawlowicz@tkn.tu-berlin.de>
*/
#include "tcp-rl.h"
#include "tcp-rl-env.h"
#include "ns3/tcp-header.h"
#include "ns3/object.h"
#include "ns3/node-list.h"
#include "ns3/core-module.h"
#include "ns3/log.h"
#include "ns3/simulator.h"
#include "ns3/tcp-socket-base.h"
#include "ns3/tcp-l4-protocol.h"
namespace ns3 {
NS_OBJECT_ENSURE_REGISTERED (TcpSocketDerived);
// Registers ns3::TcpSocketDerived (child of TcpSocketBase, group "Internet")
// and returns its TypeId.
TypeId
TcpSocketDerived::GetTypeId (void)
{
  // Function-local static: the registration chain runs once, on first call.
  static TypeId registeredId = TypeId ("ns3::TcpSocketDerived")
    .SetParent<TcpSocketBase> ()
    .SetGroupName ("Internet")
    .AddConstructor<TcpSocketDerived> ();
  return registeredId;
}
// Returns the TypeId registered for TcpSocketDerived.
TypeId
TcpSocketDerived::GetInstanceTypeId () const
{
  const TypeId ownTypeId = TcpSocketDerived::GetTypeId ();
  return ownTypeId;
}
// Default constructor: no state of its own to initialize.
TcpSocketDerived::TcpSocketDerived (void)
{
  // Intentionally empty.
}
// Returns the socket's m_congestionControl member.
Ptr<TcpCongestionOps>
TcpSocketDerived::GetCongestionControlAlgorithm ()
{
  Ptr<TcpCongestionOps> algorithm = m_congestionControl;
  return algorithm;
}
// Destructor: nothing to release in this class.
TcpSocketDerived::~TcpSocketDerived (void)
{
  // Intentionally empty.
}
NS_LOG_COMPONENT_DEFINE ("ns3::TcpRlBase");
NS_OBJECT_ENSURE_REGISTERED (TcpRlBase);
// Registers ns3::TcpRlBase (child of TcpCongestionOps, group "Internet")
// and returns its TypeId.
TypeId
TcpRlBase::GetTypeId (void)
{
  // Function-local static: the registration chain runs once, on first call.
  static TypeId registeredId = TypeId ("ns3::TcpRlBase")
    .SetParent<TcpCongestionOps> ()
    .SetGroupName ("Internet")
    .AddConstructor<TcpRlBase> ();
  return registeredId;
}
// Default constructor: starts without a socket and without a gym environment.
// The environment is created later, on the first congestion-control hook call.
TcpRlBase::TcpRlBase (void)
  : TcpCongestionOps ()
{
  NS_LOG_FUNCTION (this);
  m_tcpGymEnv = 0;
  m_tcpSocket = 0;
}
// Copy constructor: copies the TcpCongestionOps part only. The socket and gym
// environment pointers are NOT taken from the source; the copy starts without
// them and creates its own environment on its first hook call.
TcpRlBase::TcpRlBase (const TcpRlBase& sock)
  : TcpCongestionOps (sock)
{
  NS_LOG_FUNCTION (this);
  m_tcpGymEnv = 0;
  m_tcpSocket = 0;
}
// Destructor: drops the references to the gym environment and the socket.
TcpRlBase::~TcpRlBase (void)
{
  m_tcpGymEnv = 0;
  m_tcpSocket = 0;
}
// Returns a fresh identifier: 1 on the first call, then 2, 3, ...
// The counter is a function-local static shared by all instances.
uint64_t
TcpRlBase::GenerateUuid ()
{
  static uint64_t lastIssued = 0;
  return ++lastIssued;
}
// Base-class placeholder: only logs the call and creates no environment.
// TcpRl and TcpRlTimeBased override this with real environment creation.
void
TcpRlBase::CreateGymEnv()
{
  // should never be called, only child classes: TcpRl and TcpRlTimeBased
  NS_LOG_FUNCTION (this);
}
// Finds the TCP socket whose congestion-control object is this instance and
// wires it to the gym environment.
//
// Walks every node in NodeList and every entry of that node's TcpL4Protocol
// "SocketList" attribute. The first socket whose congestion-control algorithm
// is `this` is stored in m_tcpSocket. Its "Tx"/"Rx" trace sources are then
// connected to m_tcpGymEnv, and the environment is given the socket's node id.
//
// Nodes without an aggregated TcpL4Protocol, and sockets without a
// congestion-control object, are skipped instead of being dereferenced.
// Asserts (in builds with asserts enabled) if no matching socket exists.
void
TcpRlBase::ConnectSocketCallbacks()
{
  NS_LOG_FUNCTION (this);
  bool foundSocket = false;
  for (NodeList::Iterator i = NodeList::Begin (); i != NodeList::End (); ++i)
    {
      Ptr<Node> node = *i;
      Ptr<TcpL4Protocol> tcp = node->GetObject<TcpL4Protocol> ();
      if (!tcp)
        {
          // GetObject returns a null pointer when no TCP stack is aggregated
          // to this node; such a node cannot own the socket.
          continue;
        }
      ObjectVectorValue socketVec;
      tcp->GetAttribute ("SocketList", socketVec);
      NS_LOG_DEBUG("Node: " << node->GetId() << " TCP socket num: " << socketVec.GetN());
      uint32_t sockNum = socketVec.GetN();
      for (uint32_t j = 0; j < sockNum; j++)
        {
          Ptr<Object> sockObj = socketVec.Get(j);
          Ptr<TcpSocketBase> tcpSocket = DynamicCast<TcpSocketBase> (sockObj);
          NS_LOG_DEBUG("Node: " << node->GetId() << " TCP Socket: " << tcpSocket);
          if (!tcpSocket)
            {
              continue;
            }
          // TcpSocketDerived exists only to expose the congestion-control
          // pointer of a TcpSocketBase; the cast is unchecked on purpose.
          Ptr<TcpSocketDerived> dtcpSocket = StaticCast<TcpSocketDerived>(tcpSocket);
          Ptr<TcpCongestionOps> ca = dtcpSocket->GetCongestionControlAlgorithm();
          if (!ca)
            {
              // Socket has no congestion-control object, so it cannot be ours.
              continue;
            }
          NS_LOG_DEBUG("CA name: " << ca->GetName());
          Ptr<TcpRlBase> rlCa = DynamicCast<TcpRlBase>(ca);
          if (rlCa == this)
            {
              NS_LOG_DEBUG("Found TcpRl CA!");
              foundSocket = true;
              m_tcpSocket = tcpSocket;
              break;
            }
        }
      if (foundSocket)
        {
          break;
        }
    }
  NS_ASSERT_MSG(m_tcpSocket, "TCP socket was not found.");
  if (m_tcpSocket)
    {
      NS_LOG_DEBUG("Found TCP Socket: " << m_tcpSocket);
      m_tcpSocket->TraceConnectWithoutContext ("Tx", MakeCallback (&TcpGymEnv::TxPktTrace, m_tcpGymEnv));
      m_tcpSocket->TraceConnectWithoutContext ("Rx", MakeCallback (&TcpGymEnv::RxPktTrace, m_tcpGymEnv));
      NS_LOG_DEBUG("Connect socket callbacks " << m_tcpSocket->GetNode()->GetId());
      m_tcpGymEnv->SetNodeId(m_tcpSocket->GetNode()->GetId());
    }
}
// Returns the algorithm name, "TcpRlBase".
std::string
TcpRlBase::GetName () const
{
  return std::string ("TcpRlBase");
}
// Forwards the slow-start-threshold query to the gym environment, creating
// the environment first if it does not exist yet. Returns 0 when no
// environment is available even after the creation attempt.
uint32_t
TcpRlBase::GetSsThresh (Ptr<const TcpSocketState> state,
                        uint32_t bytesInFlight)
{
  NS_LOG_FUNCTION (this << state << bytesInFlight);
  if (!m_tcpGymEnv)
    {
      // Lazy creation on the first hook call.
      CreateGymEnv();
    }
  if (!m_tcpGymEnv)
    {
      return 0;
    }
  return m_tcpGymEnv->GetSsThresh(state, bytesInFlight);
}
// Forwards the window-increase hook to the gym environment, creating the
// environment first if it does not exist yet. Does nothing when no
// environment is available even after the creation attempt.
void
TcpRlBase::IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked)
{
  NS_LOG_FUNCTION (this << tcb << segmentsAcked);
  if (!m_tcpGymEnv)
    {
      // Lazy creation on the first hook call.
      CreateGymEnv();
    }
  if (!m_tcpGymEnv)
    {
      return;
    }
  m_tcpGymEnv->IncreaseWindow(tcb, segmentsAcked);
}
// Forwards the packets-acked hook to the gym environment, creating the
// environment first if it does not exist yet. Does nothing when no
// environment is available even after the creation attempt.
void
TcpRlBase::PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt)
{
  NS_LOG_FUNCTION (this);
  if (!m_tcpGymEnv)
    {
      // Lazy creation on the first hook call.
      CreateGymEnv();
    }
  if (!m_tcpGymEnv)
    {
      return;
    }
  m_tcpGymEnv->PktsAcked(tcb, segmentsAcked, rtt);
}
// Forwards the congestion-state hook to the gym environment, creating the
// environment first if it does not exist yet. Does nothing when no
// environment is available even after the creation attempt.
void
TcpRlBase::CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState)
{
  NS_LOG_FUNCTION (this);
  if (!m_tcpGymEnv)
    {
      // Lazy creation on the first hook call.
      CreateGymEnv();
    }
  if (!m_tcpGymEnv)
    {
      return;
    }
  m_tcpGymEnv->CongestionStateSet(tcb, newState);
}
// Forwards the congestion-window-event hook to the gym environment, creating
// the environment first if it does not exist yet. Does nothing when no
// environment is available even after the creation attempt.
void
TcpRlBase::CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event)
{
  NS_LOG_FUNCTION (this);
  if (!m_tcpGymEnv)
    {
      // Lazy creation on the first hook call.
      CreateGymEnv();
    }
  if (!m_tcpGymEnv)
    {
      return;
    }
  m_tcpGymEnv->CwndEvent(tcb, event);
}
// Returns a copy of this congestion-control object made with CopyObject.
// NOTE(review): the copy is requested as a TcpRlBase, and the TcpRl /
// TcpRlTimeBased declarations do not override Fork - confirm that forked
// sockets end up with the intended subclass.
Ptr<TcpCongestionOps>
TcpRlBase::Fork ()
{
  Ptr<TcpRlBase> duplicate = CopyObject<TcpRlBase> (this);
  return duplicate;
}
NS_OBJECT_ENSURE_REGISTERED (TcpRl);
// Registers ns3::TcpRl (child of TcpRlBase, group "Internet") and returns its
// TypeId.
//
// Attributes:
//   "Reward"  (double, default 1.0)   -> m_reward
//   "Penalty" (double, default -10.0) -> m_penalty
// Both values are handed to the gym environment by TcpRl::CreateGymEnv.
//
// The "Penalty" help text used to be a copy of the "Reward" text; it now
// describes the penalty.
TypeId
TcpRl::GetTypeId (void)
{
  // Function-local static: the registration chain runs once, on first call.
  static TypeId tid = TypeId ("ns3::TcpRl")
    .SetParent<TcpRlBase> ()
    .SetGroupName ("Internet")
    .AddConstructor<TcpRl> ()
    .AddAttribute ("Reward", "Reward when increasing congestion window.",
                   DoubleValue (1.0),
                   MakeDoubleAccessor (&TcpRl::m_reward),
                   MakeDoubleChecker<double> ())
    .AddAttribute ("Penalty", "Penalty value handed to the gym environment.",
                   DoubleValue (-10.0),
                   MakeDoubleAccessor (&TcpRl::m_penalty),
                   MakeDoubleChecker<double> ())
  ;
  return tid;
}
// Default constructor: all setup is done by TcpRlBase; only logs the call.
TcpRl::TcpRl (void)
  : TcpRlBase ()
{
  NS_LOG_FUNCTION (this);
}
// Copy constructor: copies the base part and the configured reward/penalty.
// Without the explicit member copies, a copy would fall back to the in-class
// defaults instead of the source object's configured values.
// The socket and gym environment are not shared: TcpRlBase's copy
// constructor starts the copy without them.
TcpRl::TcpRl (const TcpRl& sock)
  : TcpRlBase (sock),
    m_reward (sock.m_reward),
    m_penalty (sock.m_penalty)
{
  NS_LOG_FUNCTION (this);
}
// Destructor: nothing to release in this class.
TcpRl::~TcpRl (void)
{
  // Intentionally empty.
}
// Returns the algorithm name, "TcpRl".
std::string
TcpRl::GetName () const
{
  return std::string ("TcpRl");
}
void
TcpRl::CreateGymEnv()
{
NS_LOG_FUNCTION (this);
Ptr<TcpEventGymEnv> env = CreateObject<TcpEventGymEnv>();
env->SetSocketUuid(TcpRlBase::GenerateUuid());
env->SetReward(m_reward);
env->SetPenalty(m_penalty);
m_tcpGymEnv = env;
ConnectSocketCallbacks();
}
NS_OBJECT_ENSURE_REGISTERED (TcpRlTimeBased);
// Registers ns3::TcpRlTimeBased (child of TcpRlBase, group "Internet") and
// returns its TypeId.
//
// Attributes:
//   "Duration" (Time, default 10000 ms) -> m_duration
//   "StepTime" (Time, default 100 ms)   -> m_timeStep
//   "Reward"   (double, default 1.0)    -> m_reward
//   "Penalty"  (double, default -1.0)   -> m_penalty
TypeId
TcpRlTimeBased::GetTypeId (void)
{
  // Function-local static: the registration chain runs once, on first call.
  static TypeId registeredId = TypeId ("ns3::TcpRlTimeBased")
    .SetParent<TcpRlBase> ()
    .SetGroupName ("Internet")
    .AddConstructor<TcpRlTimeBased> ()
    .AddAttribute ("Duration",
                   "Simulation Duration. Default: 10000ms",
                   TimeValue (MilliSeconds (10000)),
                   MakeTimeAccessor (&TcpRlTimeBased::m_duration),
                   MakeTimeChecker ())
    .AddAttribute ("StepTime",
                   "Step interval used in TCP env. Default: 100ms",
                   TimeValue (MilliSeconds (100)),
                   MakeTimeAccessor (&TcpRlTimeBased::m_timeStep),
                   MakeTimeChecker ())
    .AddAttribute ("Reward", "Reward for increasing congestion window.",
                   DoubleValue (1.0),
                   MakeDoubleAccessor (&TcpRlTimeBased::m_reward),
                   MakeDoubleChecker<double> ())
    .AddAttribute ("Penalty", "Penalty for increasing congestion window.",
                   DoubleValue (-1.0),
                   MakeDoubleAccessor (&TcpRlTimeBased::m_penalty),
                   MakeDoubleChecker<double> ());
  return registeredId;
}
// Default constructor: all setup is done by TcpRlBase; only logs the call.
TcpRlTimeBased::TcpRlTimeBased (void)
  : TcpRlBase ()
{
  NS_LOG_FUNCTION (this);
}
// Copy constructor: copies the base part and the configured duration, step
// time, reward and penalty. m_reward and m_penalty have no in-class
// initializer, so without the explicit member copies a copy would hold
// indeterminate values.
// The socket and gym environment are not shared: TcpRlBase's copy
// constructor starts the copy without them.
TcpRlTimeBased::TcpRlTimeBased (const TcpRlTimeBased& sock)
  : TcpRlBase (sock),
    m_duration (sock.m_duration),
    m_timeStep (sock.m_timeStep),
    m_reward (sock.m_reward),
    m_penalty (sock.m_penalty)
{
  NS_LOG_FUNCTION (this);
}
// Destructor: nothing to release in this class.
TcpRlTimeBased::~TcpRlTimeBased (void)
{
  // Intentionally empty.
}
// Returns the algorithm name, "TcpRlTimeBased".
std::string
TcpRlTimeBased::GetName () const
{
  return std::string ("TcpRlTimeBased");
}
void
TcpRlTimeBased::CreateGymEnv()
{
NS_LOG_FUNCTION (this);
Ptr<TcpTimeStepGymEnv> env = CreateObject<TcpTimeStepGymEnv> ();
env->SetSocketUuid(TcpRlBase::GenerateUuid());
env->SetDuration(m_duration);
env->SetTimeStep(m_timeStep);
env->SetReward(m_reward);
env->SetPenalty(m_penalty);
m_tcpGymEnv = env;
ConnectSocketCallbacks();
}
} // namespace ns3

127
tcp-rl.h Normal file
View File

@ -0,0 +1,127 @@
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
/*
* Copyright (c) 2018 Technische Universität Berlin
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation;
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*
* Author: Piotr Gawlowicz <gawlowicz@tkn.tu-berlin.de>
*/
#ifndef TCP_RL_H
#define TCP_RL_H
#include "ns3/tcp-congestion-ops.h"
#include "ns3/opengym-module.h"
#include "ns3/tcp-socket-base.h"
namespace ns3 {
class TcpSocketBase;
class Time;
class TcpGymEnv;
// used to get pointer to Congestion Algorithm
// Thin TcpSocketBase subclass whose only addition is a public getter for the
// socket's congestion-control object. TcpRlBase::ConnectSocketCallbacks
// casts existing sockets to this type to call that getter.
class TcpSocketDerived : public TcpSocketBase
{
public:
static TypeId GetTypeId (void);
virtual TypeId GetInstanceTypeId () const;
TcpSocketDerived (void);
virtual ~TcpSocketDerived (void);
// Returns the socket's m_congestionControl member.
Ptr<TcpCongestionOps> GetCongestionControlAlgorithm ();
};
// Congestion-control algorithm that delegates every TcpCongestionOps hook to
// a TcpGymEnv. The environment is created lazily through CreateGymEnv() on
// the first hook call; this base class's CreateGymEnv() creates nothing.
class TcpRlBase : public TcpCongestionOps
{
public:
/**
* \brief Get the type ID.
* \return the object TypeId
*/
static TypeId GetTypeId (void);
TcpRlBase ();
/**
* \brief Copy constructor.
* \param sock object to copy.
*/
TcpRlBase (const TcpRlBase& sock);
~TcpRlBase ();
// TcpCongestionOps interface; each hook forwards to m_tcpGymEnv when it exists.
virtual std::string GetName () const;
virtual uint32_t GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight);
virtual void IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked);
virtual void PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt);
virtual void CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState);
virtual void CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event);
virtual Ptr<TcpCongestionOps> Fork ();
protected:
// Returns 1, 2, 3, ... from a counter shared by all instances.
static uint64_t GenerateUuid ();
// Overridden by subclasses to create and store the concrete environment.
virtual void CreateGymEnv();
// Locates the socket that uses this object and connects its Tx/Rx traces
// to m_tcpGymEnv.
void ConnectSocketCallbacks();
// OpenGymEnv interface
// Both pointers are null until CreateGymEnv / ConnectSocketCallbacks set them.
Ptr<TcpSocketBase> m_tcpSocket;
Ptr<TcpGymEnv> m_tcpGymEnv;
};
// TcpRlBase variant whose CreateGymEnv() builds a TcpEventGymEnv and passes
// it the reward and penalty configured here.
class TcpRl : public TcpRlBase
{
public:
static TypeId GetTypeId (void);
TcpRl ();
TcpRl (const TcpRl& sock);
~TcpRl ();
virtual std::string GetName () const;
private:
virtual void CreateGymEnv();
// OpenGymEnv env
// Exposed as the "Reward" / "Penalty" attributes in GetTypeId.
float m_reward {1.0};
// NOTE(review): the "Penalty" attribute registered in TcpRl::GetTypeId
// defaults to -10.0, which differs from this initializer - confirm which
// value is intended.
float m_penalty {-100.0};
};
// TcpRlBase variant whose CreateGymEnv() builds a TcpTimeStepGymEnv and
// passes it the duration, time step, reward and penalty configured here.
class TcpRlTimeBased : public TcpRlBase
{
public:
static TypeId GetTypeId (void);
TcpRlTimeBased ();
TcpRlTimeBased (const TcpRlTimeBased& sock);
~TcpRlTimeBased ();
virtual std::string GetName () const;
private:
virtual void CreateGymEnv();
// Exposed as the "Duration" / "StepTime" / "Reward" / "Penalty" attributes
// in GetTypeId, which also supplies their defaults; none of these members
// has an in-class initializer.
Time m_duration;
Time m_timeStep;
float m_reward;
float m_penalty;
};
} // namespace ns3
#endif /* TCP_RL_H */

138
tcp_base.py Normal file
View File

@ -0,0 +1,138 @@
__author__ = "Piotr Gawlowicz"
__copyright__ = "Copyright (c) 2018, Technische Universität Berlin"
__version__ = "0.1.0"
__email__ = "gawlowicz@tkn.tu-berlin.de"
class Tcp(object):
    """Base class for TCP congestion-control agents.

    Stores the observation and action spaces handed over through
    :meth:`set_spaces`. Subclasses override :meth:`get_action`.
    """

    def __init__(self):
        super(Tcp, self).__init__()
        # Defined up front so both attributes exist (as None) even before
        # set_spaces() has been called.
        self.obsSpace = None
        self.actSpace = None

    def set_spaces(self, obs, act):
        """Remember the observation space *obs* and the action space *act*."""
        self.obsSpace = obs
        self.actSpace = act

    def get_action(self, obs, reward, done, info):
        """Return the action for *obs*.

        The base class makes no decision and returns ``None``.
        """
        return None
class TcpEventBased(Tcp):
    """Event-based TCP agent that always answers with fixed window sizes.

    Observation layout (index: meaning):

    ====== =============================================================
    0      unique socket ID
    1      TCP env type: event-based = 0 / time-based = 1
    2      sim time in us
    3      unique node ID
    4      current ssThreshold
    5      current congestion window size
    6      segment size
    7      number of acked segments
    8      estimated bytes in flight
    9      last estimation of RTT (us)
    10     min value of RTT (us)
    11     function from the Congestion Algorithm (CA) interface:
           GET_SS_THRESH = 0 (packet loss), INCREASE_WINDOW (packet
           acked), PKTS_ACKED (unused), CONGESTION_STATE_SET (unused),
           CWND_EVENT (unused)
    12     Congestion Algorithm (CA) state: CA_OPEN = 0, CA_DISORDER,
           CA_CWR, CA_RECOVERY, CA_LOSS, CA_LAST_STATE
    13     Congestion Algorithm (CA) event: CA_EVENT_TX_START = 0,
           CA_EVENT_CWND_RESTART, CA_EVENT_COMPLETE_CWR, CA_EVENT_LOSS,
           CA_EVENT_ECN_NO_CE, CA_EVENT_ECN_IS_CE, CA_EVENT_DELAYED_ACK,
           CA_EVENT_NON_DELAYED_ACK
    14     ECN state: ECN_DISABLED = 0, ECN_IDLE, ECN_CE_RCVD,
           ECN_SENDING_ECE, ECN_ECE_RCVD, ECN_CWR_SENT
    ====== =============================================================
    """

    # Position of the segment size inside the observation vector.
    SEGMENT_SIZE_INDEX = 6

    def __init__(self):
        super(TcpEventBased, self).__init__()

    def get_action(self, obs, reward, done, info):
        """Return ``[new_ssThresh, new_cWnd]`` for the observation *obs*.

        Only the segment size is read from *obs*: the slow-start threshold
        is set to 5 segments and the congestion window to 10 segments.
        *reward*, *done* and *info* are accepted but not used.
        """
        segmentSize = obs[self.SEGMENT_SIZE_INDEX]

        # compute new values
        new_cWnd = 10 * segmentSize
        new_ssThresh = 5 * segmentSize

        # return actions
        return [new_ssThresh, new_cWnd]
class TcpTimeBased(Tcp):
    """Time-based TCP agent that always answers with fixed window sizes.

    Observation layout (index: meaning):

    ====== ==========================================
    0      unique socket ID
    1      TCP env type: event-based = 0 / time-based = 1
    2      sim time in us
    3      unique node ID
    4      current ssThreshold
    5      current congestion window size
    6      segment size
    7      bytesInFlightSum
    8      bytesInFlightAvg
    9      segmentsAckedSum
    10     segmentsAckedAvg
    11     avgRtt
    12     minRtt
    13     avgInterTx
    14     avgInterRx
    15     throughput
    ====== ==========================================
    """

    # Position of the segment size inside the observation vector.
    SEGMENT_SIZE_INDEX = 6

    def __init__(self):
        super(TcpTimeBased, self).__init__()

    def get_action(self, obs, reward, done, info):
        """Return ``[new_ssThresh, new_cWnd]`` for the observation *obs*.

        Only the segment size is read from *obs*: the slow-start threshold
        is set to 5 segments and the congestion window to 10 segments.
        *reward*, *done* and *info* are accepted but not used.
        """
        segmentSize = obs[self.SEGMENT_SIZE_INDEX]

        # compute new values
        new_cWnd = 10 * segmentSize
        new_ssThresh = 5 * segmentSize

        # return actions
        return [new_ssThresh, new_cWnd]