From 44717b8f3e45ae2de91272b1fd4cbb7b078ca8e1 Mon Sep 17 00:00:00 2001
From: Kaushik Narayan R
Date: Sat, 14 May 2022 23:02:03 +0530
Subject: [PATCH] Upload files

---
 TCP-RL-Agent.py | 286 +++++++++++++++++++
 sim.cc          | 333 ++++++++++++++++++++++
 tcp-rl-env.cc   | 714 ++++++++++++++++++++++++++++++++++++++++++++++++
 tcp-rl-env.h    | 210 ++++++++++++++
 tcp-rl.cc       | 382 ++++++++++++++++++++++++++
 tcp-rl.h        | 127 +++++++++
 tcp_base.py     | 138 ++++++++++
 7 files changed, 2190 insertions(+)
 create mode 100755 TCP-RL-Agent.py
 create mode 100644 sim.cc
 create mode 100644 tcp-rl-env.cc
 create mode 100644 tcp-rl-env.h
 create mode 100644 tcp-rl.cc
 create mode 100644 tcp-rl.h
 create mode 100644 tcp_base.py

diff --git a/TCP-RL-Agent.py b/TCP-RL-Agent.py
new file mode 100755
index 0000000..9bb202b
--- /dev/null
+++ b/TCP-RL-Agent.py
@@ -0,0 +1,286 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""Train a small DQN agent that sets TCP cWnd/ssThresh through ns3-gym.
+
+Every step the agent picks one of three fixed cWnd deltas, sends the new
+[ssThresh, cWnd] pair to the ns-3 simulation and fits the network on the
+returned reward. Step details are logged to run.log (stdout if that file
+cannot be opened) and four summary plots are shown at the end.
+"""
+import sys
+import argparse
+
+import numpy as np
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+
+import tensorflow as tf
+
+from ns3gym import ns3env
+from tcp_base import TcpTimeBased, TcpEventBased
+
+# Only an I/O failure should trigger the stdout fallback; a bare `except`
+# would also swallow KeyboardInterrupt and genuine programming errors.
+try:
+    w_file = open('run.log', 'w')
+except OSError:
+    w_file = sys.stdout
+
+parser = argparse.ArgumentParser(description='Start simulation script on/off')
+parser.add_argument('--start',
+                    type=int,
+                    default=1,
+                    help='Start ns-3 simulation script 0/1, Default: 1')
+parser.add_argument('--iterations',
+                    type=int,
+                    default=1,
+                    help='Number of iterations, Default: 1')
+parser.add_argument('--steps',
+                    type=int,
+                    default=100,
+                    help='Number of steps, Default 100')
+parser.add_argument('--debug',
+                    type=int,
+                    default=0,
+                    help='Show debug output 0/1, Default 0')
+args = parser.parse_args()
+
+startSim = bool(args.start)
+iterationNum = int(args.iterations)
+maxSteps = int(args.steps)
+
+port = 5555
+simTime = maxSteps / 10.0 # seconds
+stepTime = simTime / 200.0 # seconds
+seed = 12
+simArgs = {"--duration": simTime,}
+
+dashes = "-"*18
+input("[{}Press Enter to start{}]".format(dashes, dashes))
+
+# create environment
+env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs)
+
+ob_space = env.observation_space
+ac_space = env.action_space
+
+# TODO: right now, the next action is selected inside the loop, rather than using get_action.
+# this is because we use the decaying epsilon-greedy algo which needs to use the live model
+# somehow change or put that logic in an `RLTCP` class that inherits from the Tcp class, like in tcp_base.py,
+# then move the class to tcp_base.py and use that agent here
+def get_agent(state):
+    """Return the TCP agent cached for the socket in `state`, creating it on first use.
+
+    state[0] is read as the socket UUID and state[1] as the env type
+    (0 = event-based, anything else = time-based).
+    """
+    socketUuid = state[0]
+    tcpEnvType = state[1]
+    tcpAgent = get_agent.tcpAgents.get(socketUuid, None)
+    if tcpAgent is None:
+        # get a new agent based on the selected env type
+        if tcpEnvType == 0:
+            # event-based = 0
+            tcpAgent = TcpEventBased()
+        else:
+            # time-based = 1
+            tcpAgent = TcpTimeBased()
+        tcpAgent.set_spaces(get_agent.ob_space, get_agent.ac_space)
+        get_agent.tcpAgents[socketUuid] = tcpAgent
+
+    return tcpAgent
+
+# initialize agent variables
+# (useless until the above todo is fixed)
+get_agent.tcpAgents = {}
+get_agent.ob_space = ob_space
+get_agent.ac_space = ac_space
+
+def modeler(input_size, output_size):
+    """
+    Build a fully connected network that maps a state to one Q-value per action.
+
+    The output layer is linear: the training targets are unbounded Q-values
+    (reward + discounted max future Q), which a softmax layer, whose outputs
+    are confined to [0, 1] and sum to 1, cannot represent.
+    """
+    model = tf.keras.Sequential()
+
+    # input layer
+    model.add(tf.keras.layers.Dense((input_size + output_size) // 2, input_shape=(input_size,), activation='relu'))
+
+    # optional extra hidden layer of mean size of input and output
+    # model.add(tf.keras.layers.Dense((input_size + output_size) // 2, activation='relu'))
+
+    # output layer
+    # maps the previous layer to output_size Q-value estimates
+    model.add(tf.keras.layers.Dense(output_size, activation='linear'))
+
+    return model
+
+state_size = ob_space.shape[0] - 4 # ignoring 4 env attributes
+
+action_size = 3
+# action index -> change applied to cWnd
+action_mapping = {0: 0, 1: 600, 2: -150}
+
+# build model
+# Q-value regression, hence mean squared error rather than a classification loss
+model = modeler(state_size, action_size)
+model.compile(
+    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2),
+    loss='mse'
+)
+model.summary()
+
+# initialize decaying epsilon-greedy algorithm
+# fine-tune to ensure balance of exploration and exploitation
+epsilon = 1.0
+epsilon_decay_param = iterationNum * 5
+min_epsilon = 0.1
+epsilon_decay = (((epsilon_decay_param*maxSteps) - 1.0) / (epsilon_decay_param*maxSteps))
+
+# initialize Q-learning's discount factor
+discount_factor = 0.95
+
+rewardsum = 0
+rew_history = []
+cWnd_history = []
+pred_cWnd_history = []
+rtt_history = []
+
+done = False
+
+pretty_slash = ['\\', '|', '/', '-']
+
+try:
+    for iteration in range(iterationNum):
+        # set initial state
+        state = env.reset()
+        # ignore env attributes: socketID, env type, sim time, nodeID
+        state = state[4:]
+
+        # plain ints keep `cWnd + negative delta` a signed computation even
+        # if the observation arrives as an unsigned numpy type
+        cWnd = int(state[1])
+        # floor for the requested cWnd: one segment (index 2 is segmentSize
+        # in the observation order emitted by tcp-rl-env.cc)
+        min_cWnd = int(state[2])
+
+        state = np.reshape(state, [1, state_size])
+        try:
+            for step in range(maxSteps):
+                pretty_index = step % 4
+                print("\r{}\r[{}] Logging to file {} {}".format(
+                    ' '*(25+len(w_file.name)),
+                    pretty_slash[pretty_index],
+                    w_file.name,
+                    '.'*(pretty_index+1)
+                ), end='')
+
+                print("[+] Step: {}".format(step+1), file=w_file)
+
+                # Q-values of the current state; used both for greedy
+                # selection and as the base of the training target below
+                q_values = model.predict(state, verbose=0)
+
+                # Epsilon-greedy selection
+                if step == 0 or np.random.rand() < epsilon:
+                    # explore new situation
+                    action_index = np.random.randint(0, action_size)
+                    print("\t[*] Random exploration. Selected action: {}".format(action_index), file=w_file)
+                else:
+                    # exploit gained knowledge
+                    action_index = np.argmax(q_values[0])
+                    print("\t[*] Exploiting gained knowledge. Selected action: {}".format(action_index), file=w_file)
+
+                # Calculate action
+                # clamp so a decrease can never request a negative window
+                new_cWnd = max(cWnd + action_mapping[action_index], min_cWnd)
+                new_ssThresh = int(cWnd/2)
+                actions = [new_ssThresh, new_cWnd]
+
+                # Take action step on environment and get feedback
+                next_state, reward, done, _ = env.step(actions)
+
+                rewardsum += reward
+
+                next_state = next_state[4:]
+                cWnd = int(next_state[1])
+                rtt = next_state[7]
+
+                print("\t[#] Next state: ", next_state, file=w_file)
+                print("\t[!] Reward: ", reward, file=w_file)
+
+                next_state = np.reshape(next_state, [1, state_size])
+
+                # Train incrementally
+                # DQN - function approximation using neural networks
+                target = reward
+                if not done:
+                    target = (reward + discount_factor * np.amax(model.predict(next_state, verbose=0)[0]))
+                q_values[0][action_index] = target
+                model.fit(state, q_values, epochs=1, verbose=0)
+
+                # Update state
+                state = next_state
+
+                if done:
+                    print("[X] Stopping: step: {}, reward sum: {}, epsilon: {:.2}"
+                            .format(step+1, rewardsum, epsilon),
+                            file=w_file)
+                    break
+
+                if epsilon > min_epsilon:
+                    epsilon *= epsilon_decay
+
+                # Record information
+                rew_history.append(rewardsum)
+                rtt_history.append(rtt)
+                cWnd_history.append(cWnd)
+                pred_cWnd_history.append(new_cWnd)
+
+            print("\n[O] Iteration over.", file=w_file)
+            print("[-] Final epsilon value: ", epsilon, file=w_file)
+            print("[-] Final reward sum: ", rewardsum, file=w_file)
+            print()
+
+        finally:
+            # terminate the '\r' progress line; no `break` here, which
+            # would silently discard any in-flight exception
+            print()
+        # if str(input("[?] Continue to next iteration? [Y/n]: ") or "Y").lower() != "y":
+        #     break
+finally:
+    # always shut the environment down and close the log file
+    env.close()
+    if w_file is not sys.stdout:
+        w_file.close()
+
+mpl.rcdefaults()
+mpl.rcParams.update({'font.size': 12})
+fig, ax = plt.subplots(2, 2, figsize=(4,2))
+
+ax[0, 0].plot(range(len(cWnd_history)), cWnd_history, marker="", linestyle="-")
+ax[0, 0].set_title('Congestion windows')
+ax[0, 0].set_xlabel('Steps')
+ax[0, 0].set_ylabel('Actual CWND')
+
+ax[0, 1].plot(range(len(pred_cWnd_history)), pred_cWnd_history, marker="", linestyle="-")
+ax[0, 1].set_title('Predicted values')
+ax[0, 1].set_xlabel('Steps')
+ax[0, 1].set_ylabel('Predicted CWND')
+
+ax[1, 0].plot(range(len(rtt_history)), rtt_history, marker="", linestyle="-")
+ax[1, 0].set_title('RTT over time')
+ax[1, 0].set_xlabel('Steps')
+ax[1, 0].set_ylabel('RTT (microseconds)')
+
+ax[1, 1].plot(range(len(rew_history)), rew_history, marker="", linestyle="-")
+ax[1, 1].set_title('Reward sum plot')
+ax[1, 1].set_xlabel('Steps')
+ax[1, 1].set_ylabel('Reward sum')
+
+# lay out once every axis has its title and labels
+plt.tight_layout(pad=0.3)
+plt.show()
diff --git a/sim.cc b/sim.cc
new file mode 100644
index 0000000..1952e8e
--- /dev/null
+++ b/sim.cc
@@ -0,0 +1,333 @@
+/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
+/*
+ * Copyright (c) 2018 Piotr Gawlowicz
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 2 as
+ * published by the Free Software Foundation;
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Piotr Gawlowicz + * Based on script: ./examples/tcp/tcp-variants-comparison.cc + * + * Topology: + * + * Right Leafs (Clients) Left Leafs (Sinks) + * | \ / | + * | \ bottleneck / | + * | R0--------------R1 | + * | / \ | + * | access / \ access | + * N ----------- --------N + */ + +#include +#include +#include + +#include "ns3/core-module.h" +#include "ns3/network-module.h" +#include "ns3/internet-module.h" +#include "ns3/point-to-point-module.h" +#include "ns3/point-to-point-layout-module.h" +#include "ns3/applications-module.h" +#include "ns3/error-model.h" +#include "ns3/tcp-header.h" +#include "ns3/enum.h" +#include "ns3/event-id.h" +#include "ns3/flow-monitor-helper.h" +#include "ns3/ipv4-global-routing-helper.h" +#include "ns3/traffic-control-module.h" + +#include "ns3/opengym-module.h" +#include "tcp-rl.h" + +using namespace ns3; + +NS_LOG_COMPONENT_DEFINE ("TcpVariantsComparison"); + +static std::vector rxPkts; + +static void +CountRxPkts(uint32_t sinkId, Ptr packet, const Address & srcAddr) +{ + rxPkts[sinkId]++; +} + +static void +PrintRxCount() +{ + uint32_t size = rxPkts.size(); + NS_LOG_UNCOND("RxPkts:"); + for (uint32_t i=0; i openGymInterface; + if (transport_prot.compare ("ns3::TcpRl") == 0) + { + openGymInterface = OpenGymInterface::Get(openGymPort); + Config::SetDefault ("ns3::TcpRl::Reward", DoubleValue (2.0)); // Reward when increasing congestion window + Config::SetDefault ("ns3::TcpRl::Penalty", DoubleValue (-30.0)); // Penalty when decreasing congestion window + } + + if (transport_prot.compare ("ns3::TcpRlTimeBased") == 0) + { + openGymInterface = OpenGymInterface::Get(openGymPort); + Config::SetDefault ("ns3::TcpRlTimeBased::StepTime", TimeValue (Seconds(tcpEnvTimeStep))); // Time step of env + Config::SetDefault 
("ns3::TcpRlTimeBased::Duration", TimeValue (Seconds(duration))); // Duration of env sim + Config::SetDefault ("ns3::TcpRlTimeBased::Reward", DoubleValue (1.0)); // Reward + Config::SetDefault ("ns3::TcpRlTimeBased::Penalty", DoubleValue (-1.0)); // Penalty + } + + // Calculate the ADU size + Header* temp_header = new Ipv4Header (); + uint32_t ip_header = temp_header->GetSerializedSize (); + NS_LOG_LOGIC ("IP Header size is: " << ip_header); + delete temp_header; + temp_header = new TcpHeader (); + uint32_t tcp_header = temp_header->GetSerializedSize (); + NS_LOG_LOGIC ("TCP Header size is: " << tcp_header); + delete temp_header; + uint32_t tcp_adu_size = mtu_bytes - 20 - (ip_header + tcp_header); + NS_LOG_LOGIC ("TCP ADU size is: " << tcp_adu_size); + + // Set the simulation start and stop time + double start_time = 0.1; // it takes some time to initialise some variables, idk why + double stop_time = start_time + duration; + + // 4 MB of TCP buffer + Config::SetDefault ("ns3::TcpSocket::RcvBufSize", UintegerValue (1 << 21)); + Config::SetDefault ("ns3::TcpSocket::SndBufSize", UintegerValue (1 << 21)); + Config::SetDefault ("ns3::TcpSocketBase::Sack", BooleanValue (sack)); + // no. of packets received before an ACK is sent (why is the default 2?) 
+ Config::SetDefault ("ns3::TcpSocket::DelAckCount", UintegerValue (2)); + + + Config::SetDefault ("ns3::TcpL4Protocol::RecoveryType", + TypeIdValue (TypeId::LookupByName (recovery))); + // Select TCP variant + if (transport_prot.compare ("ns3::TcpWestwoodPlus") == 0) + { + // TcpWestwoodPlus is not an actual TypeId name; we need TcpWestwood here + Config::SetDefault ("ns3::TcpL4Protocol::SocketType", TypeIdValue (TcpWestwood::GetTypeId ())); + // the default protocol type in ns3::TcpWestwood is WESTWOOD + Config::SetDefault ("ns3::TcpWestwood::ProtocolType", EnumValue (TcpWestwood::WESTWOODPLUS)); + } + else + { + TypeId tcpTid; + NS_ABORT_MSG_UNLESS (TypeId::LookupByNameFailSafe (transport_prot, &tcpTid), "TypeId " << transport_prot << " not found"); + Config::SetDefault ("ns3::TcpL4Protocol::SocketType", TypeIdValue (TypeId::LookupByName (transport_prot))); + } + + // Configure the error model + // Here we use RateErrorModel with packet error rate + Ptr uv = CreateObject (); + uv->SetStream (50); + RateErrorModel error_model; + error_model.SetRandomVariable (uv); + error_model.SetUnit (RateErrorModel::ERROR_UNIT_PACKET); + error_model.SetRate (error_p); + + // Create the point-to-point link helpers + PointToPointHelper bottleNeckLink; + bottleNeckLink.SetDeviceAttribute ("DataRate", StringValue (bottleneck_bandwidth)); + bottleNeckLink.SetChannelAttribute ("Delay", StringValue (bottleneck_delay)); + //bottleNeckLink.SetDeviceAttribute ("ReceiveErrorModel", PointerValue (&error_model)); + + PointToPointHelper pointToPointLeaf; + pointToPointLeaf.SetDeviceAttribute ("DataRate", StringValue (access_bandwidth)); + pointToPointLeaf.SetChannelAttribute ("Delay", StringValue (access_delay)); + + PointToPointDumbbellHelper d (nLeaf, pointToPointLeaf, + nLeaf, pointToPointLeaf, + bottleNeckLink); + + // Install IP stack + InternetStackHelper stack; + stack.InstallAll (); + + // Traffic Control + TrafficControlHelper tchPfifo; + tchPfifo.SetRootQueueDisc 
("ns3::PfifoFastQueueDisc"); + + TrafficControlHelper tchCoDel; + tchCoDel.SetRootQueueDisc ("ns3::CoDelQueueDisc"); + + DataRate access_b (access_bandwidth); + DataRate bottle_b (bottleneck_bandwidth); + Time access_d (access_delay); + Time bottle_d (bottleneck_delay); + + uint32_t size = static_cast((std::min (access_b, bottle_b).GetBitRate () / 8) * + ((access_d + bottle_d + access_d) * 2).GetSeconds ()); + + Config::SetDefault ("ns3::PfifoFastQueueDisc::MaxSize", + QueueSizeValue (QueueSize (QueueSizeUnit::PACKETS, size / mtu_bytes))); + Config::SetDefault ("ns3::CoDelQueueDisc::MaxSize", + QueueSizeValue (QueueSize (QueueSizeUnit::BYTES, size))); + + if (queue_disc_type.compare ("ns3::PfifoFastQueueDisc") == 0) + { + tchPfifo.Install (d.GetLeft()->GetDevice(1)); + tchPfifo.Install (d.GetRight()->GetDevice(1)); + } + else if (queue_disc_type.compare ("ns3::CoDelQueueDisc") == 0) + { + tchCoDel.Install (d.GetLeft()->GetDevice(1)); + tchCoDel.Install (d.GetRight()->GetDevice(1)); + } + else + { + NS_FATAL_ERROR ("Queue not recognized. 
Allowed values are ns3::CoDelQueueDisc or ns3::PfifoFastQueueDisc"); + } + + // Assign IP Addresses + d.AssignIpv4Addresses (Ipv4AddressHelper ("10.1.1.0", "255.255.255.0"), + Ipv4AddressHelper ("10.2.1.0", "255.255.255.0"), + Ipv4AddressHelper ("10.3.1.0", "255.255.255.0")); + + + NS_LOG_INFO ("Initialize Global Routing."); + Ipv4GlobalRoutingHelper::PopulateRoutingTables (); + + // Install apps in left and right nodes + uint16_t port = 50000; + Address sinkLocalAddress (InetSocketAddress (Ipv4Address::GetAny (), port)); + PacketSinkHelper sinkHelper ("ns3::TcpSocketFactory", sinkLocalAddress); + ApplicationContainer sinkApps; + for (uint32_t i = 0; i < d.RightCount (); ++i) + { + sinkHelper.SetAttribute ("Protocol", TypeIdValue (TcpSocketFactory::GetTypeId ())); + sinkApps.Add (sinkHelper.Install (d.GetRight (i))); + } + sinkApps.Start (Seconds (0.0)); + sinkApps.Stop (Seconds (stop_time)); + + for (uint32_t i = 0; i < d.LeftCount (); ++i) + { + // Create an on/off app sending packets to the left side + AddressValue remoteAddress (InetSocketAddress (d.GetRightIpv4Address (i), port)); + Config::SetDefault ("ns3::TcpSocket::SegmentSize", UintegerValue (tcp_adu_size)); + BulkSendHelper ftp ("ns3::TcpSocketFactory", Address ()); + ftp.SetAttribute ("Remote", remoteAddress); + ftp.SetAttribute ("SendSize", UintegerValue (tcp_adu_size)); + ftp.SetAttribute ("MaxBytes", UintegerValue (data_mbytes * 1000000)); + + ApplicationContainer clientApp = ftp.Install (d.GetLeft (i)); + clientApp.Start (Seconds (start_time * i)); // Start after sink + clientApp.Stop (Seconds (stop_time - 3)); // Stop before the sink + } + + // Flow monitor + FlowMonitorHelper flowHelper; + if (flow_monitor) + { + flowHelper.InstallAll (); + } + + // Count RX packets + for (uint32_t i = 0; i < d.RightCount (); ++i) + { + rxPkts.push_back(0); + Ptr pktSink = DynamicCast(sinkApps.Get(i)); + pktSink->TraceConnectWithoutContext ("Rx", MakeBoundCallback (&CountRxPkts, i)); + } + + Simulator::Stop 
(Seconds (stop_time)); + Simulator::Run (); + + if (flow_monitor) + { + flowHelper.SerializeToXmlFile (prefix_file_name + ".flowmonitor", true, true); + } + + if (transport_prot.compare ("ns3::TcpRl") == 0 or transport_prot.compare ("ns3::TcpRlTimeBased") == 0) + { + openGymInterface->NotifySimulationEnd(); + } + + PrintRxCount(); + Simulator::Destroy (); + return 0; +} diff --git a/tcp-rl-env.cc b/tcp-rl-env.cc new file mode 100644 index 0000000..de6577b --- /dev/null +++ b/tcp-rl-env.cc @@ -0,0 +1,714 @@ +/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ +/* + * Copyright (c) 2018 Technische Universität Berlin + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation; + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. 
+ * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Piotr Gawlowicz + */ + +#include "tcp-rl-env.h" +#include "ns3/tcp-header.h" +#include "ns3/object.h" +#include "ns3/core-module.h" +#include "ns3/log.h" +#include "ns3/simulator.h" +#include "ns3/tcp-socket-base.h" +#include +#include + + +namespace ns3 { + +NS_LOG_COMPONENT_DEFINE ("ns3::TcpGymEnv"); +NS_OBJECT_ENSURE_REGISTERED (TcpGymEnv); + +TcpGymEnv::TcpGymEnv () +{ + NS_LOG_FUNCTION (this); + SetOpenGymInterface(OpenGymInterface::Get()); +} + +TcpGymEnv::~TcpGymEnv () +{ + NS_LOG_FUNCTION (this); +} + +TypeId +TcpGymEnv::GetTypeId (void) +{ + static TypeId tid = TypeId ("ns3::TcpGymEnv") + .SetParent () + .SetGroupName ("OpenGym") + ; + + return tid; +} + +void +TcpGymEnv::DoDispose () +{ + NS_LOG_FUNCTION (this); +} + +void +TcpGymEnv::SetNodeId(uint32_t id) +{ + NS_LOG_FUNCTION (this); + m_nodeId = id; +} + +void +TcpGymEnv::SetSocketUuid(uint32_t id) +{ + NS_LOG_FUNCTION (this); + m_socketUuid = id; +} + +std::string +TcpGymEnv::GetTcpCongStateName(const TcpSocketState::TcpCongState_t state) +{ + std::string stateName = "UNKNOWN"; + switch(state) { + case TcpSocketState::CA_OPEN: + stateName = "CA_OPEN"; + break; + case TcpSocketState::CA_DISORDER: + stateName = "CA_DISORDER"; + break; + case TcpSocketState::CA_CWR: + stateName = "CA_CWR"; + break; + case TcpSocketState::CA_RECOVERY: + stateName = "CA_RECOVERY"; + break; + case TcpSocketState::CA_LOSS: + stateName = "CA_LOSS"; + break; + case TcpSocketState::CA_LAST_STATE: + stateName = "CA_LAST_STATE"; + break; + default: + stateName = "UNKNOWN"; + break; + } + return stateName; +} + +std::string +TcpGymEnv::GetTcpCAEventName(const TcpSocketState::TcpCAEvent_t event) +{ + std::string eventName = "UNKNOWN"; + switch(event) { + case TcpSocketState::CA_EVENT_TX_START: + eventName 
= "CA_EVENT_TX_START"; + break; + case TcpSocketState::CA_EVENT_CWND_RESTART: + eventName = "CA_EVENT_CWND_RESTART"; + break; + case TcpSocketState::CA_EVENT_COMPLETE_CWR: + eventName = "CA_EVENT_COMPLETE_CWR"; + break; + case TcpSocketState::CA_EVENT_LOSS: + eventName = "CA_EVENT_LOSS"; + break; + case TcpSocketState::CA_EVENT_ECN_NO_CE: + eventName = "CA_EVENT_ECN_NO_CE"; + break; + case TcpSocketState::CA_EVENT_ECN_IS_CE: + eventName = "CA_EVENT_ECN_IS_CE"; + break; + case TcpSocketState::CA_EVENT_DELAYED_ACK: + eventName = "CA_EVENT_DELAYED_ACK"; + break; + case TcpSocketState::CA_EVENT_NON_DELAYED_ACK: + eventName = "CA_EVENT_NON_DELAYED_ACK"; + break; + default: + eventName = "UNKNOWN"; + break; + } + return eventName; +} + +/* +Define action space +*/ +Ptr +TcpGymEnv::GetActionSpace() +{ + // new_ssThresh + // new_cWnd + uint32_t parameterNum = 2; + float low = 0.0; + float high = 65535; + std::vector shape = {parameterNum,}; + std::string dtype = TypeNameGet (); + + Ptr box = CreateObject (low, high, shape, dtype); + NS_LOG_INFO ("MyGetActionSpace: " << box); + return box; +} + +/* +Define game over condition +*/ +bool +TcpGymEnv::GetGameOver() +{ + m_isGameOver = false; + bool test = false; + static float stepCounter = 0.0; + stepCounter += 1; + if (stepCounter == 10 && test) { + m_isGameOver = true; + } + NS_LOG_INFO ("MyGetGameOver: " << m_isGameOver); + return m_isGameOver; +} + +/* +Define reward function +*/ +float +TcpGymEnv::GetReward() +{ + NS_LOG_INFO("MyGetReward: " << m_envReward); + return m_envReward; +} + +/* +Define extra info. 
Optional +*/ +std::string +TcpGymEnv::GetExtraInfo() +{ + NS_LOG_INFO("MyGetExtraInfo: " << m_info); + return m_info; +} + +/* +Execute received actions +*/ +bool +TcpGymEnv::ExecuteActions(Ptr action) +{ + Ptr > box = DynamicCast >(action); + m_new_ssThresh = box->GetValue(0); + m_new_cWnd = box->GetValue(1); + + NS_LOG_INFO ("MyExecuteActions: " << action); + return true; +} + + +NS_OBJECT_ENSURE_REGISTERED (TcpEventGymEnv); + +TcpEventGymEnv::TcpEventGymEnv () : TcpGymEnv() +{ + NS_LOG_FUNCTION (this); +} + +TcpEventGymEnv::~TcpEventGymEnv () +{ + NS_LOG_FUNCTION (this); +} + +TypeId +TcpEventGymEnv::GetTypeId (void) +{ + static TypeId tid = TypeId ("ns3::TcpEventGymEnv") + .SetParent () + .SetGroupName ("OpenGym") + .AddConstructor () + ; + + return tid; +} + +void +TcpEventGymEnv::DoDispose () +{ + NS_LOG_FUNCTION (this); +} + +void +TcpEventGymEnv::SetReward(float value) +{ + NS_LOG_FUNCTION (this); + m_reward = value; +} + +void +TcpEventGymEnv::SetPenalty(float value) +{ + NS_LOG_FUNCTION (this); + m_penalty = value; +} + +/* +Define observation space +*/ +Ptr +TcpEventGymEnv::GetObservationSpace() +{ + // socket unique ID + // tcp env type: event-based = 0 / time-based = 1 + // sim time in us + // node ID + // ssThresh + // cWnd + // segmentSize + // segmentsAcked + // bytesInFlight + // rtt in us + // min rtt in us + // called func + // congestion algorithm (CA) state + // CA event + // ECN state + uint32_t parameterNum = 10; + float low = 0.0; + float high = 1000000000.0; + std::vector shape = {parameterNum,}; + std::string dtype = TypeNameGet (); + + Ptr box = CreateObject (low, high, shape, dtype); + NS_LOG_INFO ("MyGetObservationSpace: " << box); + return box; +} + +/* +Collect observations +*/ +Ptr +TcpEventGymEnv::GetObservation() +{ + uint32_t parameterNum = 10; + std::vector shape = {parameterNum,}; + + Ptr > box = CreateObject >(shape); + + box->AddValue(m_socketUuid); + box->AddValue(0); + box->AddValue(Simulator::Now().GetMicroSeconds ()); + 
box->AddValue(m_nodeId); + box->AddValue(m_tcb->m_ssThresh); + box->AddValue(m_tcb->m_cWnd); + box->AddValue(m_tcb->m_segmentSize); + box->AddValue(m_segmentsAcked); + box->AddValue(m_bytesInFlight); + box->AddValue(m_rtt.GetMicroSeconds ()); + //box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ()); + //box->AddValue(m_calledFunc); + //box->AddValue(m_tcb->m_congState); + //box->AddValue(m_event); + //box->AddValue(m_tcb->m_ecnState); + + // Print data + NS_LOG_INFO ("MyGetObservation: " << box); + return box; +} + +void +TcpEventGymEnv::TxPktTrace(Ptr, const TcpHeader&, Ptr) +{ + NS_LOG_FUNCTION (this); +} + +void +TcpEventGymEnv::RxPktTrace(Ptr, const TcpHeader&, Ptr) +{ + NS_LOG_FUNCTION (this); +} + +uint32_t +TcpEventGymEnv::GetSsThresh (Ptr tcb, uint32_t bytesInFlight) +{ + NS_LOG_FUNCTION (this); + // pkt was lost, so penalty + m_envReward = m_penalty; + + NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " GetSsThresh, BytesInFlight: " << bytesInFlight); + m_calledFunc = CalledFunc_t::GET_SS_THRESH; + m_info = "GetSsThresh"; + m_tcb = tcb; + m_bytesInFlight = bytesInFlight; + Notify(); + return m_new_ssThresh; +} + +void +TcpEventGymEnv::IncreaseWindow (Ptr tcb, uint32_t segmentsAcked) +{ + NS_LOG_FUNCTION (this); + // pkt was acked, so reward + m_envReward = m_reward; + + NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " IncreaseWindow, SegmentsAcked: " << segmentsAcked); + m_calledFunc = CalledFunc_t::INCREASE_WINDOW; + m_info = "IncreaseWindow"; + m_tcb = tcb; + m_segmentsAcked = segmentsAcked; + Notify(); + tcb->m_cWnd = m_new_cWnd; +} + +void +TcpEventGymEnv::PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt) +{ + NS_LOG_FUNCTION (this); + NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " PktsAcked, SegmentsAcked: " << segmentsAcked << " Rtt: " << rtt); + m_calledFunc = CalledFunc_t::PKTS_ACKED; + m_info = "PktsAcked"; + m_tcb = tcb; + m_segmentsAcked = segmentsAcked; + m_rtt = rtt; +} + +void 
+TcpEventGymEnv::CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState) +{ + NS_LOG_FUNCTION (this); + std::string stateName = GetTcpCongStateName(newState); + NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CongestionStateSet: " << newState << " " << stateName); + + m_calledFunc = CalledFunc_t::CONGESTION_STATE_SET; + m_info = "CongestionStateSet"; + m_tcb = tcb; + m_newState = newState; +} + +void +TcpEventGymEnv::CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event) +{ + NS_LOG_FUNCTION (this); + std::string eventName = GetTcpCAEventName(event); + NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CwndEvent: " << event << " " << eventName); + + m_calledFunc = CalledFunc_t::CWND_EVENT; + m_info = "CwndEvent"; + m_tcb = tcb; + m_event = event; +} + + +NS_OBJECT_ENSURE_REGISTERED (TcpTimeStepGymEnv); + +TcpTimeStepGymEnv::TcpTimeStepGymEnv () : TcpGymEnv() +{ + NS_LOG_FUNCTION (this); +} + +void +TcpTimeStepGymEnv::ScheduleNextStateRead () +{ + NS_LOG_FUNCTION (this); + Simulator::Schedule (m_timeStep, &TcpTimeStepGymEnv::ScheduleNextStateRead, this); + Notify(); +} + +TcpTimeStepGymEnv::~TcpTimeStepGymEnv () +{ + NS_LOG_FUNCTION (this); +} + +TypeId +TcpTimeStepGymEnv::GetTypeId (void) +{ + static TypeId tid = TypeId ("ns3::TcpTimeStepGymEnv") + .SetParent () + .SetGroupName ("OpenGym") + .AddConstructor () + ; + + return tid; +} + +void +TcpTimeStepGymEnv::DoDispose () +{ + NS_LOG_FUNCTION (this); +} + +void +TcpTimeStepGymEnv::SetDuration(Time value) +{ + NS_LOG_FUNCTION (this); + m_duration = value; +} + +void +TcpTimeStepGymEnv::SetTimeStep(Time value) +{ + NS_LOG_FUNCTION (this); + m_timeStep = value; +} + +void +TcpTimeStepGymEnv::SetReward(float value) +{ + NS_LOG_FUNCTION (this); + m_reward = value; +} + +void +TcpTimeStepGymEnv::SetPenalty(float value) +{ + NS_LOG_FUNCTION (this); + m_penalty = value; +} + +/* +Define observation space +*/ +Ptr +TcpTimeStepGymEnv::GetObservationSpace() +{ + // socket unique 
ID + // tcp env type: event-based = 0 / time-based = 1 + // sim time in us + // node ID + // ssThresh + // cWnd + // segmentSize + // bytesInFlightSum + // bytesInFlightAvg + // segmentsAckedSum + // segmentsAckedAvg + // avgRtt + // minRtt + // avgInterTx + // avgInterRx + // throughput + uint32_t parameterNum = 16; + float low = 0.0; + float high = 1000000000.0; + std::vector shape = {parameterNum,}; + std::string dtype = TypeNameGet (); + + Ptr box = CreateObject (low, high, shape, dtype); + NS_LOG_INFO ("MyGetObservationSpace: " << box); + return box; +} + +/* +Collect observations +*/ +Ptr +TcpTimeStepGymEnv::GetObservation() +{ + uint32_t parameterNum = 16; + std::vector shape = {parameterNum,}; + + Ptr > box = CreateObject >(shape); + + box->AddValue(m_socketUuid); + box->AddValue(1); + box->AddValue(Simulator::Now().GetMicroSeconds ()); + box->AddValue(m_nodeId); + box->AddValue(m_tcb->m_ssThresh); + box->AddValue(m_tcb->m_cWnd); + box->AddValue(m_tcb->m_segmentSize); + + //bytesInFlightSum + uint64_t bytesInFlightSum = std::accumulate(m_bytesInFlight.begin(), m_bytesInFlight.end(), 0); + box->AddValue(bytesInFlightSum); + + //bytesInFlightAvg + uint64_t bytesInFlightAvg = 0; + if (m_bytesInFlight.size()) { + bytesInFlightAvg = bytesInFlightSum / m_bytesInFlight.size(); + } + box->AddValue(bytesInFlightAvg); + + //segmentsAckedSum + uint64_t segmentsAckedSum = std::accumulate(m_segmentsAcked.begin(), m_segmentsAcked.end(), 0); + box->AddValue(segmentsAckedSum); + + //segmentsAckedAvg + uint64_t segmentsAckedAvg = 0; + if (m_segmentsAcked.size()) { + segmentsAckedAvg = segmentsAckedSum / m_segmentsAcked.size(); + } + box->AddValue(segmentsAckedAvg); + + //avgRtt + Time avgRtt = Seconds(0.0); + if(m_rttSampleNum) { + avgRtt = m_rttSum / m_rttSampleNum; + } + box->AddValue(avgRtt.GetMicroSeconds ()); + +/*---------------------------------------------------------------------------------------------------*/ 
+/*---------------------------------------------------------------------------------------------------*/ +/*---------------------------------------------------------------------------------------------------*/ + + // Update reward based on overall average of avgRtt over all steps so far + // only when agent increases cWnd + // TODO: this is not the right way of doing this. + // place this somewhere else. see TcpEventGymEnv, how they've done it. + + if (m_new_cWnd > m_old_cWnd && m_totalAvgRttSum > 0 && avgRtt > 0) { + // when agent increases cWnd + if ((m_totalAvgRttSum / m_totalAvgRttNum) >= avgRtt) { + // give reward for decreasing avgRtt + m_envReward = m_reward; + } else { + // give penalty for increasing avgRtt + m_envReward = m_penalty; + } + } else { + // agent has not increased cWnd + m_envReward = 0; + } + + // Update m_totalAvgRtSum and m_totalAvgRttNum + m_totalAvgRttSum += avgRtt; + m_totalAvgRttNum++; + + m_old_cWnd = m_new_cWnd; +/*---------------------------------------------------------------------------------------------------*/ +/*---------------------------------------------------------------------------------------------------*/ +/*---------------------------------------------------------------------------------------------------*/ + + //m_minRtt + box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ()); + + //avgInterTx + Time avgInterTx = Seconds(0.0); + if (m_interTxTimeNum) { + avgInterTx = m_interTxTimeSum / m_interTxTimeNum; + } + box->AddValue(avgInterTx.GetMicroSeconds ()); + + //avgInterRx + Time avgInterRx = Seconds(0.0); + if (m_interRxTimeNum) { + avgInterRx = m_interRxTimeSum / m_interRxTimeNum; + } + box->AddValue(avgInterRx.GetMicroSeconds ()); + + //throughput bytes/s + float throughput = (segmentsAckedSum * m_tcb->m_segmentSize) / m_timeStep.GetSeconds(); + box->AddValue(throughput); + + // Print data + NS_LOG_INFO ("MyGetObservation: " << box); + + m_bytesInFlight.clear(); + m_segmentsAcked.clear(); + + m_rttSampleNum = 0; + m_rttSum 
= MicroSeconds (0.0); + + m_interTxTimeNum = 0; + m_interTxTimeSum = MicroSeconds (0.0); + + m_interRxTimeNum = 0; + m_interRxTimeSum = MicroSeconds (0.0); + + return box; +} + +void +TcpTimeStepGymEnv::TxPktTrace(Ptr, const TcpHeader&, Ptr) +{ + NS_LOG_FUNCTION (this); + if ( m_lastPktTxTime > MicroSeconds(0.0) ) { + Time interTxTime = Simulator::Now() - m_lastPktTxTime; + m_interTxTimeSum += interTxTime; + m_interTxTimeNum++; + } + + m_lastPktTxTime = Simulator::Now(); +} + +void +TcpTimeStepGymEnv::RxPktTrace(Ptr, const TcpHeader&, Ptr) +{ + NS_LOG_FUNCTION (this); + if ( m_lastPktRxTime > MicroSeconds(0.0) ) { + Time interRxTime = Simulator::Now() - m_lastPktRxTime; + m_interRxTimeSum += interRxTime; + m_interRxTimeNum++; + } + + m_lastPktRxTime = Simulator::Now(); +} + +uint32_t +TcpTimeStepGymEnv::GetSsThresh (Ptr tcb, uint32_t bytesInFlight) +{ + NS_LOG_FUNCTION (this); + NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " GetSsThresh, BytesInFlight: " << bytesInFlight); + m_tcb = tcb; + m_bytesInFlight.push_back(bytesInFlight); + + if (!m_started) { + m_started = true; + Notify(); + ScheduleNextStateRead(); + } + + // action + return m_new_ssThresh; +} + +void +TcpTimeStepGymEnv::IncreaseWindow (Ptr tcb, uint32_t segmentsAcked) +{ + NS_LOG_FUNCTION (this); + NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " IncreaseWindow, SegmentsAcked: " << segmentsAcked); + m_tcb = tcb; + m_segmentsAcked.push_back(segmentsAcked); + m_bytesInFlight.push_back(tcb->m_bytesInFlight); + + if (!m_started) { + m_started = true; + Notify(); + ScheduleNextStateRead(); + } + // action + tcb->m_cWnd = m_new_cWnd; +} + +void +TcpTimeStepGymEnv::PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt) +{ + NS_LOG_FUNCTION (this); + NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " PktsAcked, SegmentsAcked: " << segmentsAcked << " Rtt: " << rtt); + m_tcb = tcb; + m_rttSum += rtt; + m_rttSampleNum++; +} + +void +TcpTimeStepGymEnv::CongestionStateSet (Ptr 
tcb, const TcpSocketState::TcpCongState_t newState) +{ + NS_LOG_FUNCTION (this); + std::string stateName = GetTcpCongStateName(newState); + NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CongestionStateSet: " << newState << " " << stateName); + m_tcb = tcb; +} + +void +TcpTimeStepGymEnv::CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event) +{ + NS_LOG_FUNCTION (this); + std::string eventName = GetTcpCAEventName(event); + NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CwndEvent: " << event << " " << eventName); +} + +} // namespace ns3 diff --git a/tcp-rl-env.h b/tcp-rl-env.h new file mode 100644 index 0000000..45eaff2 --- /dev/null +++ b/tcp-rl-env.h @@ -0,0 +1,210 @@ +/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ +/* + * Copyright (c) 2018 Technische Universität Berlin + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation; + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. 
+ * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Piotr Gawlowicz + */ + +#ifndef TCP_RL_ENV_H +#define TCP_RL_ENV_H + +#include "ns3/opengym-module.h" +#include "ns3/tcp-socket-base.h" +#include + +namespace ns3 { + +class Packet; +class TcpHeader; +class TcpSocketBase; +class Time; + + +class TcpGymEnv : public OpenGymEnv +{ +public: + TcpGymEnv (); + virtual ~TcpGymEnv (); + static TypeId GetTypeId (void); + virtual void DoDispose (); + + void SetNodeId(uint32_t id); + void SetSocketUuid(uint32_t id); + + std::string GetTcpCongStateName(const TcpSocketState::TcpCongState_t state); + std::string GetTcpCAEventName(const TcpSocketState::TcpCAEvent_t event); + + // OpenGym interface + virtual Ptr GetActionSpace(); + virtual bool GetGameOver(); + virtual float GetReward(); + virtual std::string GetExtraInfo(); + virtual bool ExecuteActions(Ptr action); + + virtual Ptr GetObservationSpace() = 0; + virtual Ptr GetObservation() = 0; + + // trace packets, e.g. 
for calculating inter tx/rx time + virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr) = 0; + virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr) = 0; + + // TCP congestion control interface + virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight) = 0; + virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked) = 0; + // optional functions used to collect obs + virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt) = 0; + virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState) = 0; + virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event) = 0; + + typedef enum + { + GET_SS_THRESH = 0, + INCREASE_WINDOW, + PKTS_ACKED, + CONGESTION_STATE_SET, + CWND_EVENT, + } CalledFunc_t; + +protected: + uint32_t m_nodeId; + uint32_t m_socketUuid; + + // state + // obs has to be implemented in child class + + // game over + bool m_isGameOver; + + // reward + float m_envReward; + + // extra info + std::string m_info; + + // actions + uint32_t m_new_ssThresh; + uint32_t m_new_cWnd; +}; + + +class TcpEventGymEnv : public TcpGymEnv +{ +public: + TcpEventGymEnv (); + virtual ~TcpEventGymEnv (); + static TypeId GetTypeId (void); + virtual void DoDispose (); + + void SetReward(float value); + void SetPenalty(float value); + + // OpenGym interface + virtual Ptr GetObservationSpace(); + Ptr GetObservation(); + + // trace packets, e.g. 
for calculating inter tx/rx time + virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr); + virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr); + + // TCP congestion control interface + virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); + virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); + // optional functions used to collect obs + virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); + virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); + virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); + +private: + // state + CalledFunc_t m_calledFunc; + Ptr m_tcb; + uint32_t m_bytesInFlight; + uint32_t m_segmentsAcked; + Time m_rtt; + TcpSocketState::TcpCongState_t m_newState; + TcpSocketState::TcpCAEvent_t m_event; + + // reward + float m_reward; + float m_penalty; +}; + + +class TcpTimeStepGymEnv : public TcpGymEnv +{ +public: + TcpTimeStepGymEnv (); + + virtual ~TcpTimeStepGymEnv (); + static TypeId GetTypeId (void); + virtual void DoDispose (); + + void SetDuration(Time value); + void SetTimeStep(Time value); + void SetReward(float value); + void SetPenalty(float value); + + // OpenGym interface + virtual Ptr GetObservationSpace(); + Ptr GetObservation(); + + // trace packets, e.g. 
for calculating inter tx/rx time + virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr); + virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr); + + // TCP congestion control interface + virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); + virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); + // optional functions used to collect obs + virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); + virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); + virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); + +private: + void ScheduleNextStateRead(); + bool m_started {false}; + Time m_duration; + Time m_timeStep; + + // state + Ptr m_tcb; + std::vector m_bytesInFlight; + std::vector m_segmentsAcked; + + uint64_t m_rttSampleNum {0}; + Time m_rttSum {MicroSeconds (0.0)}; + + Time m_lastPktTxTime {MicroSeconds(0.0)}; + Time m_lastPktRxTime {MicroSeconds(0.0)}; + uint64_t m_interTxTimeNum {0}; + Time m_interTxTimeSum {MicroSeconds (0.0)}; + uint64_t m_interRxTimeNum {0}; + Time m_interRxTimeSum {MicroSeconds (0.0)}; + Time m_prevAvgRtt {MicroSeconds (0.0)}; + Time m_totalAvgRttSum {MicroSeconds (0.0)}; + uint64_t m_totalAvgRttNum {0}; + uint32_t m_old_cWnd {0}; + + // reward + float m_reward; + float m_penalty; +}; + + + +} // namespace ns3 + +#endif /* TCP_RL_ENV_H */ diff --git a/tcp-rl.cc b/tcp-rl.cc new file mode 100644 index 0000000..c715d1a --- /dev/null +++ b/tcp-rl.cc @@ -0,0 +1,382 @@ +/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ +/* + * Copyright (c) 2018 Technische Universität Berlin + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation; + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR 
A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Piotr Gawlowicz + */ + +#include "tcp-rl.h" +#include "tcp-rl-env.h" +#include "ns3/tcp-header.h" +#include "ns3/object.h" +#include "ns3/node-list.h" +#include "ns3/core-module.h" +#include "ns3/log.h" +#include "ns3/simulator.h" +#include "ns3/tcp-socket-base.h" +#include "ns3/tcp-l4-protocol.h" + + +namespace ns3 { + + +NS_OBJECT_ENSURE_REGISTERED (TcpSocketDerived); + +TypeId +TcpSocketDerived::GetTypeId (void) +{ + static TypeId tid = TypeId ("ns3::TcpSocketDerived") + .SetParent () + .SetGroupName ("Internet") + .AddConstructor () + ; + return tid; +} + +TypeId +TcpSocketDerived::GetInstanceTypeId () const +{ + return TcpSocketDerived::GetTypeId (); +} + +TcpSocketDerived::TcpSocketDerived (void) +{ +} + +Ptr +TcpSocketDerived::GetCongestionControlAlgorithm () +{ + return m_congestionControl; +} + +TcpSocketDerived::~TcpSocketDerived (void) +{ +} + + +NS_LOG_COMPONENT_DEFINE ("ns3::TcpRlBase"); +NS_OBJECT_ENSURE_REGISTERED (TcpRlBase); + +TypeId +TcpRlBase::GetTypeId (void) +{ + static TypeId tid = TypeId ("ns3::TcpRlBase") + .SetParent () + .SetGroupName ("Internet") + .AddConstructor () + ; + return tid; +} + +TcpRlBase::TcpRlBase (void) + : TcpCongestionOps () +{ + NS_LOG_FUNCTION (this); + m_tcpSocket = 0; + m_tcpGymEnv = 0; +} + +TcpRlBase::TcpRlBase (const TcpRlBase& sock) + : TcpCongestionOps (sock) +{ + NS_LOG_FUNCTION (this); + m_tcpSocket = 0; + m_tcpGymEnv = 0; +} + +TcpRlBase::~TcpRlBase (void) +{ + m_tcpSocket = 0; + m_tcpGymEnv = 0; +} + +uint64_t +TcpRlBase::GenerateUuid () +{ + static uint64_t uuid = 0; + uuid++; + return uuid; +} + +void +TcpRlBase::CreateGymEnv() +{ + NS_LOG_FUNCTION (this); + // should never be called, only child 
classes: TcpRl and TcpRlTimeBased +} + +void +TcpRlBase::ConnectSocketCallbacks() +{ + NS_LOG_FUNCTION (this); + + bool foundSocket = false; + for (NodeList::Iterator i = NodeList::Begin (); i != NodeList::End (); ++i) { + Ptr node = *i; + Ptr tcp = node->GetObject (); + + ObjectVectorValue socketVec; + tcp->GetAttribute ("SocketList", socketVec); + NS_LOG_DEBUG("Node: " << node->GetId() << " TCP socket num: " << socketVec.GetN()); + + uint32_t sockNum = socketVec.GetN(); + for (uint32_t j=0; j sockObj = socketVec.Get(j); + Ptr tcpSocket = DynamicCast (sockObj); + NS_LOG_DEBUG("Node: " << node->GetId() << " TCP Socket: " << tcpSocket); + if(!tcpSocket) { continue; } + + Ptr dtcpSocket = StaticCast(tcpSocket); + Ptr ca = dtcpSocket->GetCongestionControlAlgorithm(); + NS_LOG_DEBUG("CA name: " << ca->GetName()); + Ptr rlCa = DynamicCast(ca); + if (rlCa == this) { + NS_LOG_DEBUG("Found TcpRl CA!"); + foundSocket = true; + m_tcpSocket = tcpSocket; + break; + } + } + + if (foundSocket) { + break; + } + } + + NS_ASSERT_MSG(m_tcpSocket, "TCP socket was not found."); + + if(m_tcpSocket) { + NS_LOG_DEBUG("Found TCP Socket: " << m_tcpSocket); + m_tcpSocket->TraceConnectWithoutContext ("Tx", MakeCallback (&TcpGymEnv::TxPktTrace, m_tcpGymEnv)); + m_tcpSocket->TraceConnectWithoutContext ("Rx", MakeCallback (&TcpGymEnv::RxPktTrace, m_tcpGymEnv)); + NS_LOG_DEBUG("Connect socket callbacks " << m_tcpSocket->GetNode()->GetId()); + m_tcpGymEnv->SetNodeId(m_tcpSocket->GetNode()->GetId()); + } +} + +std::string +TcpRlBase::GetName () const +{ + return "TcpRlBase"; +} + +uint32_t +TcpRlBase::GetSsThresh (Ptr state, + uint32_t bytesInFlight) +{ + NS_LOG_FUNCTION (this << state << bytesInFlight); + + if (!m_tcpGymEnv) { + CreateGymEnv(); + } + + uint32_t newSsThresh = 0; + if (m_tcpGymEnv) { + newSsThresh = m_tcpGymEnv->GetSsThresh(state, bytesInFlight); + } + + return newSsThresh; +} + +void +TcpRlBase::IncreaseWindow (Ptr tcb, uint32_t segmentsAcked) +{ + NS_LOG_FUNCTION (this << tcb << 
segmentsAcked); + + if (!m_tcpGymEnv) { + CreateGymEnv(); + } + + if (m_tcpGymEnv) { + m_tcpGymEnv->IncreaseWindow(tcb, segmentsAcked); + } +} + +void +TcpRlBase::PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt) +{ + NS_LOG_FUNCTION (this); + + if (!m_tcpGymEnv) { + CreateGymEnv(); + } + + if (m_tcpGymEnv) { + m_tcpGymEnv->PktsAcked(tcb, segmentsAcked, rtt); + } +} + +void +TcpRlBase::CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState) +{ + NS_LOG_FUNCTION (this); + + if (!m_tcpGymEnv) { + CreateGymEnv(); + } + + if (m_tcpGymEnv) { + m_tcpGymEnv->CongestionStateSet(tcb, newState); + } +} + +void +TcpRlBase::CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event) +{ + NS_LOG_FUNCTION (this); + + if (!m_tcpGymEnv) { + CreateGymEnv(); + } + + if (m_tcpGymEnv) { + m_tcpGymEnv->CwndEvent(tcb, event); + } +} + +Ptr +TcpRlBase::Fork () +{ + return CopyObject (this); +} + + +NS_OBJECT_ENSURE_REGISTERED (TcpRl); + +TypeId +TcpRl::GetTypeId (void) +{ + static TypeId tid = TypeId ("ns3::TcpRl") + .SetParent () + .SetGroupName ("Internet") + .AddConstructor () + .AddAttribute ("Reward", "Reward when increasing congestion window.", + DoubleValue (1.0), + MakeDoubleAccessor (&TcpRl::m_reward), + MakeDoubleChecker ()) + .AddAttribute ("Penalty", "Penalty value passed to the gym environment.", + DoubleValue (-10.0), // NOTE(review): tcp-rl.h initializes TcpRl::m_penalty to -100.0; confirm which default is intended + MakeDoubleAccessor (&TcpRl::m_penalty), + MakeDoubleChecker ()) + ; + return tid; +} + +TcpRl::TcpRl (void) + : TcpRlBase () +{ + NS_LOG_FUNCTION (this); +} + +TcpRl::TcpRl (const TcpRl& sock) + : TcpRlBase (sock) +{ + NS_LOG_FUNCTION (this); +} + +TcpRl::~TcpRl (void) +{ +} + +std::string +TcpRl::GetName () const +{ + return "TcpRl"; +} + +void +TcpRl::CreateGymEnv() +{ + NS_LOG_FUNCTION (this); + Ptr env = CreateObject(); + env->SetSocketUuid(TcpRlBase::GenerateUuid()); + env->SetReward(m_reward); + env->SetPenalty(m_penalty); + m_tcpGymEnv = env; + + ConnectSocketCallbacks(); +} + + +NS_OBJECT_ENSURE_REGISTERED 
(TcpRlTimeBased); + +TypeId +TcpRlTimeBased::GetTypeId (void) +{ + static TypeId tid = TypeId ("ns3::TcpRlTimeBased") + .SetParent () + .SetGroupName ("Internet") + .AddConstructor () + .AddAttribute ("Duration", + "Simulation Duration. Default: 10000ms", + TimeValue (MilliSeconds (10000)), + MakeTimeAccessor (&TcpRlTimeBased::m_duration), + MakeTimeChecker ()) + .AddAttribute ("StepTime", + "Step interval used in TCP env. Default: 100ms", + TimeValue (MilliSeconds (100)), + MakeTimeAccessor (&TcpRlTimeBased::m_timeStep), + MakeTimeChecker ()) + .AddAttribute ("Reward", "Reward for increasing congestion window.", + DoubleValue (1.0), + MakeDoubleAccessor (&TcpRlTimeBased::m_reward), + MakeDoubleChecker ()) + .AddAttribute ("Penalty", "Penalty for increasing congestion window.", + DoubleValue (-1.0), + MakeDoubleAccessor (&TcpRlTimeBased::m_penalty), + MakeDoubleChecker ()) + ; + return tid; +} + +TcpRlTimeBased::TcpRlTimeBased (void) + : TcpRlBase () +{ + NS_LOG_FUNCTION (this); +} + +TcpRlTimeBased::TcpRlTimeBased (const TcpRlTimeBased& sock) + : TcpRlBase (sock) +{ + NS_LOG_FUNCTION (this); +} + +TcpRlTimeBased::~TcpRlTimeBased (void) +{ +} + +std::string +TcpRlTimeBased::GetName () const +{ + return "TcpRlTimeBased"; +} + +void +TcpRlTimeBased::CreateGymEnv() +{ + NS_LOG_FUNCTION (this); + Ptr env = CreateObject (); + env->SetSocketUuid(TcpRlBase::GenerateUuid()); + env->SetDuration(m_duration); + env->SetTimeStep(m_timeStep); + env->SetReward(m_reward); + env->SetPenalty(m_penalty); + m_tcpGymEnv = env; + + ConnectSocketCallbacks(); +} + +} // namespace ns3 diff --git a/tcp-rl.h b/tcp-rl.h new file mode 100644 index 0000000..317d0b5 --- /dev/null +++ b/tcp-rl.h @@ -0,0 +1,127 @@ +/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ +/* + * Copyright (c) 2018 Technische Universität Berlin + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published 
by the Free Software Foundation; + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Piotr Gawlowicz + */ + +#ifndef TCP_RL_H +#define TCP_RL_H + +#include "ns3/tcp-congestion-ops.h" +#include "ns3/opengym-module.h" +#include "ns3/tcp-socket-base.h" + +namespace ns3 { + +class TcpSocketBase; +class Time; +class TcpGymEnv; + + +// used to get pointer to Congestion Algorithm +class TcpSocketDerived : public TcpSocketBase +{ +public: + static TypeId GetTypeId (void); + virtual TypeId GetInstanceTypeId () const; + + TcpSocketDerived (void); + virtual ~TcpSocketDerived (void); + + Ptr GetCongestionControlAlgorithm (); +}; + + +class TcpRlBase : public TcpCongestionOps +{ +public: + /** + * \brief Get the type ID. + * \return the object TypeId + */ + static TypeId GetTypeId (void); + + TcpRlBase (); + + /** + * \brief Copy constructor. + * \param sock object to copy. 
+ */ + TcpRlBase (const TcpRlBase& sock); + + ~TcpRlBase (); + + virtual std::string GetName () const; + virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); + virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); + virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); + virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); + virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); + virtual Ptr Fork (); + +protected: + static uint64_t GenerateUuid (); + virtual void CreateGymEnv(); + void ConnectSocketCallbacks(); + + // OpenGymEnv interface + Ptr m_tcpSocket; + Ptr m_tcpGymEnv; +}; + + +class TcpRl : public TcpRlBase +{ +public: + static TypeId GetTypeId (void); + + TcpRl (); + TcpRl (const TcpRl& sock); + ~TcpRl (); + + virtual std::string GetName () const; +private: + virtual void CreateGymEnv(); + // OpenGymEnv env + float m_reward {1.0}; + float m_penalty {-100.0}; +}; + + +class TcpRlTimeBased : public TcpRlBase +{ +public: + static TypeId GetTypeId (void); + + TcpRlTimeBased (); + TcpRlTimeBased (const TcpRlTimeBased& sock); + ~TcpRlTimeBased (); + + virtual std::string GetName () const; + +private: + virtual void CreateGymEnv(); + + Time m_duration; + Time m_timeStep; + float m_reward; + float m_penalty; +}; + +} // namespace ns3 + +#endif /* TCP_RL_H */ \ No newline at end of file diff --git a/tcp_base.py b/tcp_base.py new file mode 100644 index 0000000..b81c8ee --- /dev/null +++ b/tcp_base.py @@ -0,0 +1,138 @@ +__author__ = "Piotr Gawlowicz" +__copyright__ = "Copyright (c) 2018, Technische Universität Berlin" +__version__ = "0.1.0" +__email__ = "gawlowicz@tkn.tu-berlin.de" + + +class Tcp(object): + """Base class for the TCP controllers in this module: set_spaces stores the observation/action spaces; get_action is a no-op placeholder that subclasses override.""" + def __init__(self): + super(Tcp, self).__init__() + + def set_spaces(self, obs, act): + self.obsSpace = obs + self.actSpace = act + + def get_action(self, obs, reward, done, info): + pass + + +class TcpEventBased(Tcp): + """docstring for 
TcpEventBased""" + def __init__(self): + super(TcpEventBased, self).__init__() + + def get_action(self, obs, reward, done, info): + # unique socket ID + socketUuid = obs[0] + # TCP env type: event-based = 0 / time-based = 1 + envType = obs[1] + # sim time in us + simTime_us = obs[2] + # unique node ID + nodeId = obs[3] + # current ssThreshold + ssThresh = obs[4] + # current congestion window size + cWnd = obs[5] + # segment size + segmentSize = obs[6] + # number of acked segments + segmentsAcked = obs[7] + # estimated bytes in flight + bytesInFlight = obs[8] + # last estimation of RTT + lastRtt_us = obs[9] + # min value of RTT + minRtt_us = obs[10] + # function from Congestion Algorithm (CA) interface: + # GET_SS_THRESH = 0 (packet loss), + # INCREASE_WINDOW (packet acked), + # PKTS_ACKED (unused), + # CONGESTION_STATE_SET (unused), + # CWND_EVENT (unused), + calledFunc = obs[11] + # Congestion Algorithm (CA) state: + # CA_OPEN = 0, + # CA_DISORDER, + # CA_CWR, + # CA_RECOVERY, + # CA_LOSS, + # CA_LAST_STATE + caState = obs[12] + # Congestion Algorithm (CA) event: + # CA_EVENT_TX_START = 0, + # CA_EVENT_CWND_RESTART, + # CA_EVENT_COMPLETE_CWR, + # CA_EVENT_LOSS, + # CA_EVENT_ECN_NO_CE, + # CA_EVENT_ECN_IS_CE, + # CA_EVENT_DELAYED_ACK, + # CA_EVENT_NON_DELAYED_ACK, + caEvent = obs[13] + # ECN state: + # ECN_DISABLED = 0, + # ECN_IDLE, + # ECN_CE_RCVD, + # ECN_SENDING_ECE, + # ECN_ECE_RCVD, + # ECN_CWR_SENT + ecnState = obs[14] + + # compute new values + new_cWnd = 10 * segmentSize + new_ssThresh = 5 * segmentSize + + # return actions + actions = [new_ssThresh, new_cWnd] + + return actions + + +class TcpTimeBased(Tcp): + """Time-based TCP stub: get_action unpacks the observation and returns the fixed actions [5 * segmentSize, 10 * segmentSize] as [ssThresh, cWnd].""" + def __init__(self): + super(TcpTimeBased, self).__init__() + + def get_action(self, obs, reward, done, info): + # unique socket ID + socketUuid = obs[0] + # TCP env type: event-based = 0 / time-based = 1 + envType = obs[1] + # sim time in us + simTime_us = obs[2] + # unique node ID + nodeId = obs[3] + # current 
ssThreshold + ssThresh = obs[4] + # current congestion window size + cWnd = obs[5] + # segment size + segmentSize = obs[6] + # bytesInFlightSum + bytesInFlightSum = obs[7] + # bytesInFlightAvg + bytesInFlightAvg = obs[8] + # segmentsAckedSum + segmentsAckedSum = obs[9] + # segmentsAckedAvg + segmentsAckedAvg = obs[10] + # avgRtt + avgRtt = obs[11] + # minRtt in us + minRtt = obs[12] + # avgInterTx in us + avgInterTx = obs[13] + # avgInterRx in us + avgInterRx = obs[14] + # throughput in bytes/s + throughput = obs[15] + + # compute new values + new_cWnd = 10 * segmentSize + new_ssThresh = 5 * segmentSize + + # return actions (order: ssThresh first, then cWnd) + actions = [new_ssThresh, new_cWnd] + + return actions