/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ /* * Copyright (c) 2018 Technische Universität Berlin * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 as * published by the Free Software Foundation; * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA * * Author: Piotr Gawlowicz */ #ifndef TCP_RL_ENV_H #define TCP_RL_ENV_H #include "ns3/opengym-module.h" #include "ns3/tcp-socket-base.h" #include namespace ns3 { class Packet; class TcpHeader; class TcpSocketBase; class Time; class TcpGymEnv : public OpenGymEnv { public: TcpGymEnv (); virtual ~TcpGymEnv (); static TypeId GetTypeId (void); virtual void DoDispose (); void SetNodeId(uint32_t id); void SetSocketUuid(uint32_t id); std::string GetTcpCongStateName(const TcpSocketState::TcpCongState_t state); std::string GetTcpCAEventName(const TcpSocketState::TcpCAEvent_t event); // OpenGym interface virtual Ptr GetActionSpace(); virtual bool GetGameOver(); virtual float GetReward(); virtual std::string GetExtraInfo(); virtual bool ExecuteActions(Ptr action); virtual Ptr GetObservationSpace() = 0; virtual Ptr GetObservation() = 0; // trace packets, e.g. for calculating inter tx/rx time virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr) = 0; virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr) = 0; // TCP congestion control interface virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight) = 0; virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked) = 0; // optional functions used to collect obs virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt) = 0; virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState) = 0; virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event) = 0; typedef enum { GET_SS_THRESH = 0, INCREASE_WINDOW, PKTS_ACKED, CONGESTION_STATE_SET, CWND_EVENT, } CalledFunc_t; protected: uint32_t m_nodeId; uint32_t m_socketUuid; // state // obs has to be implemented in child class // game over bool m_isGameOver; // reward float m_envReward; // extra info std::string m_info; // actions uint32_t m_new_ssThresh; uint32_t m_new_cWnd; }; class TcpEventGymEnv : public TcpGymEnv { public: TcpEventGymEnv (); virtual ~TcpEventGymEnv (); static TypeId GetTypeId (void); virtual void DoDispose (); void SetReward(float value); void SetPenalty(float value); // OpenGym interface virtual Ptr GetObservationSpace(); Ptr GetObservation(); // trace packets, e.g. for calculating inter tx/rx time virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr); virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr); // TCP congestion control interface virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); // optional functions used to collect obs virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); private: // state CalledFunc_t m_calledFunc; Ptr m_tcb; uint32_t m_bytesInFlight; uint32_t m_segmentsAcked; Time m_rtt; TcpSocketState::TcpCongState_t m_newState; TcpSocketState::TcpCAEvent_t m_event; // reward float m_reward; float m_penalty; }; class TcpTimeStepGymEnv : public TcpGymEnv { public: TcpTimeStepGymEnv (); virtual ~TcpTimeStepGymEnv (); static TypeId GetTypeId (void); virtual void DoDispose (); void SetDuration(Time value); void SetTimeStep(Time value); void SetReward(float value); void SetPenalty(float value); // OpenGym interface virtual Ptr GetObservationSpace(); Ptr GetObservation(); // trace packets, e.g. for calculating inter tx/rx time virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr); virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr); // TCP congestion control interface virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); // optional functions used to collect obs virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); private: void ScheduleNextStateRead(); bool m_started {false}; Time m_duration; Time m_timeStep; // state Ptr m_tcb; std::vector m_bytesInFlight; std::vector m_segmentsAcked; uint64_t m_rttSampleNum {0}; Time m_rttSum {MicroSeconds (0.0)}; Time m_lastPktTxTime {MicroSeconds(0.0)}; Time m_lastPktRxTime {MicroSeconds(0.0)}; uint64_t m_interTxTimeNum {0}; Time m_interTxTimeSum {MicroSeconds (0.0)}; uint64_t m_interRxTimeNum {0}; Time m_interRxTimeSum {MicroSeconds (0.0)}; Time m_prevAvgRtt {MicroSeconds (0.0)}; Time m_totalAvgRttSum {MicroSeconds (0.0)}; uint64_t m_totalAvgRttNum {0}; uint32_t m_old_cWnd {0}; // reward float m_reward; float m_penalty; }; } // namespace ns3 #endif /* TCP_RL_ENV_H */