#include "tcp-rl.h"
#include "tcp-rl-env.h"
#include "ns3/tcp-header.h"
#include "ns3/object.h"
#include "ns3/node-list.h"
#include "ns3/core-module.h"
#include "ns3/log.h"
#include "ns3/simulator.h"
#include "ns3/tcp-socket-base.h"
#include "ns3/tcp-l4-protocol.h"

namespace ns3 {

/* ------------------------------------------------------------------------
 * TcpSocketDerived
 *
 * As used in this file, TcpSocketDerived adds a public accessor for the
 * socket's congestion-control object (m_congestionControl, presumably the
 * member inherited from TcpSocketBase). TcpRlBase uses it to locate the
 * socket it is attached to.
 * ------------------------------------------------------------------------ */

NS_OBJECT_ENSURE_REGISTERED (TcpSocketDerived);

TypeId
TcpSocketDerived::GetTypeId (void)
{
  static TypeId tid = TypeId ("ns3::TcpSocketDerived")
    .SetParent<TcpSocketBase> ()
    .SetGroupName ("Internet")
    .AddConstructor<TcpSocketDerived> ()
  ;
  return tid;
}

TypeId
TcpSocketDerived::GetInstanceTypeId () const
{
  return TcpSocketDerived::GetTypeId ();
}

TcpSocketDerived::TcpSocketDerived (void)
{
}

// Returns the congestion-control object currently installed on this socket
// (may be a null Ptr if none has been set).
Ptr<TcpCongestionOps>
TcpSocketDerived::GetCongestionControlAlgorithm ()
{
  return m_congestionControl;
}

TcpSocketDerived::~TcpSocketDerived (void)
{
}

/* ------------------------------------------------------------------------
 * TcpRlBase
 *
 * Congestion-control shim that forwards every TcpCongestionOps callback to a
 * TcpGymEnv. The env is created lazily on the first callback by the virtual
 * CreateGymEnv(); the subclasses TcpRl and TcpRlTimeBased supply the actual
 * env type.
 * ------------------------------------------------------------------------ */

NS_LOG_COMPONENT_DEFINE ("ns3::TcpRlBase");
NS_OBJECT_ENSURE_REGISTERED (TcpRlBase);

TypeId
TcpRlBase::GetTypeId (void)
{
  static TypeId tid = TypeId ("ns3::TcpRlBase")
    .SetParent<TcpCongestionOps> ()
    .SetGroupName ("Internet")
    .AddConstructor<TcpRlBase> ()
  ;
  return tid;
}

TcpRlBase::TcpRlBase (void)
  : TcpCongestionOps ()
{
  NS_LOG_FUNCTION (this);
  m_tcpSocket = nullptr;
  m_tcpGymEnv = nullptr;
}

// The copy starts without a socket or env. It creates its own env (with a
// fresh UUID) on its first callback.
TcpRlBase::TcpRlBase (const TcpRlBase& sock)
  : TcpCongestionOps (sock)
{
  NS_LOG_FUNCTION (this);
  m_tcpSocket = nullptr;
  m_tcpGymEnv = nullptr;
}

TcpRlBase::~TcpRlBase (void)
{
  m_tcpSocket = nullptr;
  m_tcpGymEnv = nullptr;
}

// Returns a process-wide counter value: 1 on the first call, then 2, 3, ...
// Used to tag each created gym env with a distinct socket id.
uint64_t
TcpRlBase::GenerateUuid ()
{
  static uint64_t uuid = 0;
  uuid++;
  return uuid;
}

// Intentionally a no-op. TcpRl and TcpRlTimeBased override this to create
// their env. A plain TcpRlBase (including one produced by Fork) therefore never
// gets an env, and its callbacks fall through to the defaults below.
void
TcpRlBase::CreateGymEnv ()
{
  NS_LOG_FUNCTION (this);
  // should never be called, only child classes: TcpRl and TcpRlTimeBased
}

// Finds the TcpSocketBase whose congestion-control object is `this`. It scans
// the "SocketList" of the TcpL4Protocol on every node in NodeList. When found,
// it stores the socket in m_tcpSocket, connects the socket's "Tx"/"Rx" traces
// to m_tcpGymEnv and tells the env which node the socket lives on.
// Precondition: m_tcpGymEnv has been set (CreateGymEnv does this first).
void
TcpRlBase::ConnectSocketCallbacks ()
{
  NS_LOG_FUNCTION (this);

  bool foundSocket = false;
  for (NodeList::Iterator i = NodeList::Begin (); i != NodeList::End (); ++i)
    {
      Ptr<Node> node = *i;
      Ptr<TcpL4Protocol> tcp = node->GetObject<TcpL4Protocol> ();
      if (!tcp)
        {
          // Nodes without TCP have no TcpL4Protocol aggregated; GetObject
          // returns a null Ptr for them, so there is nothing to scan.
          NS_LOG_DEBUG ("Node: " << node->GetId () << " has no TcpL4Protocol, skipping");
          continue;
        }

      ObjectVectorValue socketVec;
      tcp->GetAttribute ("SocketList", socketVec);
      uint32_t sockNum = socketVec.GetN ();
      NS_LOG_DEBUG ("Node: " << node->GetId () << " TCP socket num: " << sockNum);

      for (uint32_t j = 0; j < sockNum; ++j)
        {
          Ptr<Object> sockObj = socketVec.Get (j);
          Ptr<TcpSocketBase> tcpSocket = DynamicCast<TcpSocketBase> (sockObj);
          NS_LOG_DEBUG ("Node: " << node->GetId () << " TCP Socket: " << tcpSocket);
          if (!tcpSocket)
            {
              continue;
            }

          // NOTE(review): this StaticCast only serves to reach the accessor.
          // The socket's dynamic type is generally TcpSocketBase, not
          // TcpSocketDerived, so the downcast is not formally valid C++.
          // It is kept as-is because it is how this module reads the CA.
          Ptr<TcpSocketDerived> dtcpSocket = StaticCast<TcpSocketDerived> (tcpSocket);
          Ptr<TcpCongestionOps> ca = dtcpSocket->GetCongestionControlAlgorithm ();
          if (!ca)
            {
              continue;
            }
          NS_LOG_DEBUG ("CA name: " << ca->GetName ());

          Ptr<TcpRlBase> rlCa = DynamicCast<TcpRlBase> (ca);
          if (rlCa == this)
            {
              NS_LOG_DEBUG ("Found TcpRl CA!");
              foundSocket = true;
              m_tcpSocket = tcpSocket;
              break;
            }
        }

      if (foundSocket)
        {
          break;
        }
    }

  NS_ASSERT_MSG (m_tcpSocket, "TCP socket was not found.");

  if (m_tcpSocket)
    {
      NS_LOG_DEBUG ("Found TCP Socket: " << m_tcpSocket);
      m_tcpSocket->TraceConnectWithoutContext ("Tx", MakeCallback (&TcpGymEnv::TxPktTrace, m_tcpGymEnv));
      m_tcpSocket->TraceConnectWithoutContext ("Rx", MakeCallback (&TcpGymEnv::RxPktTrace, m_tcpGymEnv));
      NS_LOG_DEBUG ("Connect socket callbacks " << m_tcpSocket->GetNode ()->GetId ());
      m_tcpGymEnv->SetNodeId (m_tcpSocket->GetNode ()->GetId ());
    }
}

std::string
TcpRlBase::GetName () const
{
  return "TcpRlBase";
}

// Delegates the slow-start threshold decision to the gym env. Returns 0 when
// no env exists (i.e. CreateGymEnv did not create one).
uint32_t
TcpRlBase::GetSsThresh (Ptr<const TcpSocketState> state,
                        uint32_t bytesInFlight)
{
  NS_LOG_FUNCTION (this << state << bytesInFlight);

  if (!m_tcpGymEnv)
    {
      CreateGymEnv ();
    }

  uint32_t newSsThresh = 0;
  if (m_tcpGymEnv)
    {
      newSsThresh = m_tcpGymEnv->GetSsThresh (state, bytesInFlight);
    }

  return newSsThresh;
}

// Delegates window growth to the gym env; does nothing if no env exists.
void
TcpRlBase::IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked)
{
  NS_LOG_FUNCTION (this << tcb << segmentsAcked);

  if (!m_tcpGymEnv)
    {
      CreateGymEnv ();
    }

  if (m_tcpGymEnv)
    {
      m_tcpGymEnv->IncreaseWindow (tcb, segmentsAcked);
    }
}

// Forwards ACK/RTT information to the gym env; does nothing if no env exists.
void
TcpRlBase::PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked,
                      const Time& rtt)
{
  NS_LOG_FUNCTION (this << tcb << segmentsAcked << rtt);

  if (!m_tcpGymEnv)
    {
      CreateGymEnv ();
    }

  if (m_tcpGymEnv)
    {
      m_tcpGymEnv->PktsAcked (tcb, segmentsAcked, rtt);
    }
}

// Forwards congestion-state transitions to the gym env; does nothing if no
// env exists.
void
TcpRlBase::CongestionStateSet (Ptr<TcpSocketState> tcb,
                               const TcpSocketState::TcpCongState_t newState)
{
  NS_LOG_FUNCTION (this << tcb << newState);

  if (!m_tcpGymEnv)
    {
      CreateGymEnv ();
    }

  if (m_tcpGymEnv)
    {
      m_tcpGymEnv->CongestionStateSet (tcb, newState);
    }
}

// Forwards congestion-avoidance events to the gym env; does nothing if no env
// exists.
void
TcpRlBase::CwndEvent (Ptr<TcpSocketState> tcb,
                      const TcpSocketState::TcpCAEvent_t event)
{
  NS_LOG_FUNCTION (this << tcb << event);

  if (!m_tcpGymEnv)
    {
      CreateGymEnv ();
    }

  if (m_tcpGymEnv)
    {
      m_tcpGymEnv->CwndEvent (tcb, event);
    }
}

Ptr<TcpCongestionOps>
TcpRlBase::Fork ()
{
  // NOTE(review): this always copies as TcpRlBase, even when the dynamic type
  // is TcpRl or TcpRlTimeBased (neither overrides Fork in this file). The copy
  // therefore uses the no-op TcpRlBase::CreateGymEnv, never gets an env, and
  // its GetSsThresh returns 0. Confirm whether that is intended (e.g. to keep
  // forked/accepted sockets from becoming additional RL agents) before
  // changing it.
  return CopyObject<TcpRlBase> (this);
}

/* ------------------------------------------------------------------------
 * TcpRl: event-driven variant, backed by a TcpEventGymEnv.
 * ------------------------------------------------------------------------ */

NS_OBJECT_ENSURE_REGISTERED (TcpRl);

TypeId
TcpRl::GetTypeId (void)
{
  static TypeId tid = TypeId ("ns3::TcpRl")
    .SetParent<TcpRlBase> ()
    .SetGroupName ("Internet")
    .AddConstructor<TcpRl> ()
    .AddAttribute ("Reward", "Reward when increasing congestion window.",
                   DoubleValue (1.0),
                   MakeDoubleAccessor (&TcpRl::m_reward),
                   MakeDoubleChecker<double> ())
    .AddAttribute ("Penalty", "Penalty when increasing congestion window.",
                   DoubleValue (-10.0),
                   MakeDoubleAccessor (&TcpRl::m_penalty),
                   MakeDoubleChecker<double> ())
  ;
  return tid;
}

TcpRl::TcpRl (void)
  : TcpRlBase ()
{
  NS_LOG_FUNCTION (this);
}

TcpRl::TcpRl (const TcpRl& sock)
  : TcpRlBase (sock)
{
  NS_LOG_FUNCTION (this);
  // Carry the attribute-configured values over to the copy. Nothing else in
  // this file initializes them for a copy-constructed object.
  m_reward = sock.m_reward;
  m_penalty = sock.m_penalty;
}

TcpRl::~TcpRl (void)
{
}

std::string
TcpRl::GetName () const
{
  return "TcpRl";
}

// Creates a TcpEventGymEnv configured with this object's reward/penalty and
// a fresh UUID, then attaches it to the owning socket.
void
TcpRl::CreateGymEnv ()
{
  NS_LOG_FUNCTION (this);
  Ptr<TcpEventGymEnv> env = CreateObject<TcpEventGymEnv> ();
  env->SetSocketUuid (TcpRlBase::GenerateUuid ());
  env->SetReward (m_reward);
  env->SetPenalty (m_penalty);
  m_tcpGymEnv = env;

  ConnectSocketCallbacks ();
}

/* ------------------------------------------------------------------------
 * TcpRlTimeBased: fixed-time-step variant, backed by a TcpTimeStepGymEnv.
 * ------------------------------------------------------------------------ */

NS_OBJECT_ENSURE_REGISTERED (TcpRlTimeBased);

TypeId
TcpRlTimeBased::GetTypeId (void)
{
  static TypeId tid = TypeId ("ns3::TcpRlTimeBased")
    .SetParent<TcpRlBase> ()
    .SetGroupName ("Internet")
    .AddConstructor<TcpRlTimeBased> ()
    .AddAttribute ("Duration", "Simulation Duration. Default: 10000ms",
                   TimeValue (MilliSeconds (10000)),
                   MakeTimeAccessor (&TcpRlTimeBased::m_duration),
                   MakeTimeChecker ())
    .AddAttribute ("StepTime", "Step interval used in TCP env. Default: 100ms",
                   TimeValue (MilliSeconds (100)),
                   MakeTimeAccessor (&TcpRlTimeBased::m_timeStep),
                   MakeTimeChecker ())
    .AddAttribute ("Reward", "Reward for increasing congestion window.",
                   DoubleValue (1.0),
                   MakeDoubleAccessor (&TcpRlTimeBased::m_reward),
                   MakeDoubleChecker<double> ())
    .AddAttribute ("Penalty", "Penalty for increasing congestion window.",
                   DoubleValue (-1.0),
                   MakeDoubleAccessor (&TcpRlTimeBased::m_penalty),
                   MakeDoubleChecker<double> ())
  ;
  return tid;
}

TcpRlTimeBased::TcpRlTimeBased (void)
  : TcpRlBase ()
{
  NS_LOG_FUNCTION (this);
}

TcpRlTimeBased::TcpRlTimeBased (const TcpRlTimeBased& sock)
  : TcpRlBase (sock)
{
  NS_LOG_FUNCTION (this);
  // Carry the attribute-configured values over to the copy. Nothing else in
  // this file initializes them for a copy-constructed object.
  m_duration = sock.m_duration;
  m_timeStep = sock.m_timeStep;
  m_reward = sock.m_reward;
  m_penalty = sock.m_penalty;
}

TcpRlTimeBased::~TcpRlTimeBased (void)
{
}

std::string
TcpRlTimeBased::GetName () const
{
  return "TcpRlTimeBased";
}

// Creates a TcpTimeStepGymEnv configured with this object's duration, step
// time, reward and penalty plus a fresh UUID, then attaches it to the owning
// socket.
void
TcpRlTimeBased::CreateGymEnv ()
{
  NS_LOG_FUNCTION (this);
  Ptr<TcpTimeStepGymEnv> env = CreateObject<TcpTimeStepGymEnv> ();
  env->SetSocketUuid (TcpRlBase::GenerateUuid ());
  env->SetDuration (m_duration);
  env->SetTimeStep (m_timeStep);
  env->SetReward (m_reward);
  env->SetPenalty (m_penalty);
  m_tcpGymEnv = env;

  ConnectSocketCallbacks ();
}

} // namespace ns3