#ifndef TCP_RL_ENV_H #define TCP_RL_ENV_H #include "ns3/opengym-module.h" #include "ns3/tcp-socket-base.h" #include namespace ns3 { class Packet; class TcpHeader; class TcpSocketBase; class Time; class TcpGymEnv : public OpenGymEnv { public: TcpGymEnv (); virtual ~TcpGymEnv (); static TypeId GetTypeId (void); virtual void DoDispose (); void SetNodeId(uint32_t id); void SetSocketUuid(uint32_t id); std::string GetTcpCongStateName(const TcpSocketState::TcpCongState_t state); std::string GetTcpCAEventName(const TcpSocketState::TcpCAEvent_t event); // OpenGym interface virtual Ptr GetActionSpace(); virtual bool GetGameOver(); virtual float GetReward(); virtual std::string GetExtraInfo(); virtual bool ExecuteActions(Ptr action); virtual Ptr GetObservationSpace() = 0; virtual Ptr GetObservation() = 0; // trace packets, e.g. for calculating inter tx/rx time virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr) = 0; virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr) = 0; // TCP congestion control interface virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight) = 0; virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked) = 0; // optional functions used to collect obs virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt) = 0; virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState) = 0; virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event) = 0; typedef enum { GET_SS_THRESH = 0, INCREASE_WINDOW, PKTS_ACKED, CONGESTION_STATE_SET, CWND_EVENT, } CalledFunc_t; protected: uint32_t m_nodeId; uint32_t m_socketUuid; // state // obs has to be implemented in child class // game over bool m_isGameOver; // reward float m_envReward; // extra info std::string m_info; // actions uint32_t m_new_ssThresh; uint32_t m_new_cWnd; }; class TcpEventGymEnv : public TcpGymEnv { public: TcpEventGymEnv (); virtual ~TcpEventGymEnv (); static TypeId GetTypeId (void); virtual void DoDispose (); void SetReward(float value); void SetPenalty(float value); // OpenGym interface virtual Ptr GetObservationSpace(); Ptr GetObservation(); // trace packets, e.g. for calculating inter tx/rx time virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr); virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr); // TCP congestion control interface virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); // optional functions used to collect obs virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); private: // state CalledFunc_t m_calledFunc; Ptr m_tcb; uint32_t m_bytesInFlight; uint32_t m_segmentsAcked; Time m_rtt; TcpSocketState::TcpCongState_t m_newState; TcpSocketState::TcpCAEvent_t m_event; // reward float m_reward; float m_penalty; }; class TcpTimeStepGymEnv : public TcpGymEnv { public: TcpTimeStepGymEnv (); virtual ~TcpTimeStepGymEnv (); static TypeId GetTypeId (void); virtual void DoDispose (); void SetDuration(Time value); void SetTimeStep(Time value); void SetReward(float value); void SetPenalty(float value); // OpenGym interface virtual Ptr GetObservationSpace(); Ptr GetObservation(); // trace packets, e.g. for calculating inter tx/rx time virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr); virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr); // TCP congestion control interface virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); // optional functions used to collect obs virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); private: void ScheduleNextStateRead(); bool m_started {false}; Time m_duration; Time m_timeStep; // state Ptr m_tcb; std::vector m_bytesInFlight; std::vector m_segmentsAcked; uint64_t m_rttSampleNum {0}; Time m_rttSum {MicroSeconds (0.0)}; Time m_lastPktTxTime {MicroSeconds(0.0)}; Time m_lastPktRxTime {MicroSeconds(0.0)}; uint64_t m_interTxTimeNum {0}; Time m_interTxTimeSum {MicroSeconds (0.0)}; uint64_t m_interRxTimeNum {0}; Time m_interRxTimeSum {MicroSeconds (0.0)}; Time m_prevAvgRtt {MicroSeconds (0.0)}; Time m_totalAvgRttSum {MicroSeconds (0.0)}; uint64_t m_totalAvgRttNum {0}; uint32_t m_old_cWnd {0}; // reward float m_reward; float m_penalty; }; } // namespace ns3 #endif /* TCP_RL_ENV_H */