aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Doan <daviddoan@Davids-MacBook-Pro-70.local>2023-11-02 18:57:56 -0400
committerDavid Doan <daviddoan@Davids-MacBook-Pro-70.local>2023-11-02 18:57:56 -0400
commit5d7ffd42b638f33d32c01c46f853e26a5028b552 (patch)
treeed7220178ce76a213341f14fb27c6af63e95209a
parenta8c9b48821c85e072ca9054621928e415540b12c (diff)
milestone 1
-rw-r--r--cmd/vhost/main.go128
-rw-r--r--cmd/vrouter/main.go8
-rw-r--r--pkg/ipstack/ipstack.go350
-rw-r--r--pkg/iptcp_utils/iptcp_utils.go144
-rw-r--r--pkg/tcpstack/tcpstack.go232
-rwxr-xr-xvhostbin3095650 -> 3157834 bytes
-rwxr-xr-xvrouterbin3095650 -> 3144747 bytes
7 files changed, 810 insertions, 52 deletions
diff --git a/cmd/vhost/main.go b/cmd/vhost/main.go
index c0e5ad8..57f5f7c 100644
--- a/cmd/vhost/main.go
+++ b/cmd/vhost/main.go
@@ -4,9 +4,11 @@ import (
"bufio"
"fmt"
"iptcp/pkg/ipstack"
+ // "iptcp/pkg/tcpstack"
"net/netip"
"os"
"strings"
+ "strconv"
)
func main() {
@@ -22,12 +24,13 @@ func main() {
return
}
ipstack.RegisterProtocolHandler(ipstack.TEST_PROTOCOL)
-
+ ipstack.RegisterProtocolHandler(ipstack.TCP_PROTOCOL)
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
line := scanner.Text()
- switch line {
+ tokens := strings.Split(line, " ")
+ switch tokens[0] {
case "li":
fmt.Println("Name\tAddr/Prefix\tState")
fmt.Println(ipstack.SprintInterfaces())
@@ -43,59 +46,90 @@ func main() {
case "exit":
ipstack.CleanUp()
os.Exit(0)
- default:
- if len(line) > 4 {
- if line[:4] == "down" {
- ifaceName := line[5:]
- ipstack.InterfaceDownREPL(ifaceName)
- }
+ case "down":
+ // get interface name
+ ifaceName := tokens[1]
+ ipstack.InterfaceDownREPL(ifaceName)
+ case "up":
+ // get interface name
+ ifaceName := tokens[1]
+ ipstack.InterfaceUpREPL(ifaceName)
+ case "send":
+ // get IP address and message that follows it
+ IPAndMessage := strings.Split(line, " ")
+ ipAddr := IPAndMessage[1]
+ message := IPAndMessage[2:]
- if line[:4] == "send" {
- // get IP address and message that follows it
- IPAndMessage := strings.Split(line, " ")
- ipAddr := IPAndMessage[1]
- message := IPAndMessage[2:]
+ // combine message into one string
+ messageToSend := strings.Join(message, " ")
+ messageToSendBytes := []byte(messageToSend)
- // combine message into one string
- messageToSend := strings.Join(message, " ")
- messageToSendBytes := []byte(messageToSend)
-
- address, _ := netip.ParseAddr(ipAddr)
- hop, err := ipstack.Route(address)
+ address, _ := netip.ParseAddr(ipAddr)
+ hop, err := ipstack.Route(address)
+ if err != nil {
+ fmt.Println(err)
+ continue
+ }
+ myAddr := hop.Interface.IpPrefix.Addr()
+ for _, neighbor := range ipstack.GetNeighbors()[hop.Interface.Name] {
+ if neighbor.VipAddr == address ||
+ neighbor.VipAddr == hop.VIP && hop.Type == "S" {
+ bytesWritten, err := ipstack.SendIP(&myAddr, neighbor, ipstack.TEST_PROTOCOL, messageToSendBytes, ipAddr, nil)
+ fmt.Printf("Sent %d bytes to %s\n", bytesWritten, neighbor.VipAddr.String())
if err != nil {
fmt.Println(err)
- continue
- }
- myAddr := hop.Interface.IpPrefix.Addr()
- for _, neighbor := range ipstack.GetNeighbors()[hop.Interface.Name] {
- if neighbor.VipAddr == address ||
- neighbor.VipAddr == hop.VIP && hop.Type == "S" {
- bytesWritten, err := ipstack.SendIP(&myAddr, neighbor, ipstack.TEST_PROTOCOL, messageToSendBytes, ipAddr, nil)
- fmt.Printf("Sent %d bytes to %s\n", bytesWritten, neighbor.VipAddr.String())
- if err != nil {
- fmt.Println(err)
- }
- }
}
}
}
- if len(line) > 2 {
- if line[:2] == "up" {
- // get interface name
- ifaceName := line[3:]
- ipstack.InterfaceUpREPL(ifaceName)
- }
- } else {
- fmt.Println("Invalid command: ", line)
- fmt.Println("Commands: ")
- fmt.Println(" exit Terminate this program")
- fmt.Println(" li List interfaces")
- fmt.Println(" lr List routes")
- fmt.Println(" ln List available neighbors")
- fmt.Println(" up Enable an interface")
- fmt.Println(" down Disable an interface")
- fmt.Println(" send Send test packet")
+// ********************************************TCP REPL***************************************************************
+ case "a":
+ // accept a connection
+ // get the port number
+ port := line[2:]
+ uint16Port, err := strconv.ParseUint(port, 10, 16)
+ if err != nil {
+ fmt.Println(err)
+ continue
+ }
+ // listen on the port
+ listener, err := ipstack.VListen(uint16(uint16Port))
+ if err != nil {
+ fmt.Println(err)
+ continue
+ }
+ _, err = listener.VAccept()
+ if err != nil {
+ fmt.Println(err)
+ continue
}
+ case "c":
+ // connect to a port
+ vipAddr := tokens[1]
+ port := tokens[2]
+ uint16Port, err := strconv.ParseUint(port, 10, 16)
+ if err != nil {
+ fmt.Println(err)
+ continue
+ }
+ _, err = ipstack.VConnect(vipAddr, uint16(uint16Port))
+ if err != nil {
+ fmt.Println(err)
+ continue
+ }
+ case "ls":
+ // list sockets
+ fmt.Println("SID\tLAddr\t\tLPort\tRAddr\t\tRPort\tRStatus")
+ fmt.Println(ipstack.SprintSockets())
+ default:
+ fmt.Println("Invalid command: ", line)
+ fmt.Println("Commands: ")
+ fmt.Println(" exit Terminate this program")
+ fmt.Println(" li List interfaces")
+ fmt.Println(" lr List routes")
+ fmt.Println(" ln List available neighbors")
+ fmt.Println(" up Enable an interface")
+ fmt.Println(" down Disable an interface")
+ fmt.Println(" send Send test packet")
continue
}
}
diff --git a/cmd/vrouter/main.go b/cmd/vrouter/main.go
index 2736525..f338b80 100644
--- a/cmd/vrouter/main.go
+++ b/cmd/vrouter/main.go
@@ -4,9 +4,11 @@ import (
"bufio"
"fmt"
"iptcp/pkg/ipstack"
+ // "iptcp/pkg/tcpstack"
"net/netip"
"os"
"strings"
+ // "strconv"
)
func main() {
@@ -30,13 +32,17 @@ func main() {
ipstack.RegisterProtocolHandler(ipstack.TEST_PROTOCOL)
// register the rip protocol for the router
ipstack.RegisterProtocolHandler(ipstack.RIP_PROTOCOL)
+ // register the TCP protocol for the handler
+ ipstack.RegisterProtocolHandler(ipstack.TCP_PROTOCOL)
+
// create a scanner to read from stdin for command-line inputs
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
line := scanner.Text()
- switch line {
+ tokens := strings.Split(line, " ")
+ switch tokens[0] {
// print the interfaces
case "li":
fmt.Println("Name\tAddr/Prefix\tState")
diff --git a/pkg/ipstack/ipstack.go b/pkg/ipstack/ipstack.go
index fe8d743..78a6b96 100644
--- a/pkg/ipstack/ipstack.go
+++ b/pkg/ipstack/ipstack.go
@@ -20,11 +20,15 @@ import (
"github.com/google/netstack/tcpip/header"
"github.com/pkg/errors"
"iptcp/pkg/lnxconfig"
+ "iptcp/pkg/iptcp_utils"
"log"
"net"
"net/netip"
"sync"
"time"
+ "strings"
+ "math/rand"
+ // "github.com/google/netstack/tcpip/header"
)
const (
@@ -35,6 +39,7 @@ const (
SIZE_OF_RIP_ENTRY = 12
RIP_PROTOCOL = 200
TEST_PROTOCOL = 0
+ TCP_PROTOCOL = 6
SIZE_OF_RIP_HEADER = 4
MAX_TIMEOUT = 12
)
@@ -682,16 +687,20 @@ func StartRipRoutines() {
// RegisterProtocolHandler registers a protocol handler for a given protocol number
// Returns true if the protocol number is valid, false otherwise
func RegisterProtocolHandler(protocolNum int) bool {
- if protocolNum == RIP_PROTOCOL {
+ switch protocolNum {
+ case RIP_PROTOCOL:
protocolHandlers[protocolNum] = HandleRIP
go StartRipRoutines()
return true
- }
- if protocolNum == TEST_PROTOCOL {
+ case TEST_PROTOCOL:
protocolHandlers[protocolNum] = HandleTestPackets
return true
+ case TCP_PROTOCOL:
+ protocolHandlers[protocolNum] = HandleTCP
+ return true
+ default:
+ return false
}
- return false
}
// HandleRIP handles incoming RIP packets in the following way:
@@ -853,6 +862,132 @@ func HandleTestPackets(src *Interface, message []byte, hdr *ipv4header.IPv4Heade
return nil
}
+func HandleTCP(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error {
+ fmt.Println("I see a TCP packet")
+ // ipHeaderSize := hdr.Len
+ tcpHeaderAndData := message
+ tcpHdr := iptcp_utils.ParseTCPHeader(tcpHeaderAndData)
+ tcpPayload := tcpHeaderAndData[tcpHdr.DataOffset:]
+ tcpChecksumFromHeader := tcpHdr.Checksum
+ tcpHdr.Checksum = 0
+ tcpComputedChecksum := iptcp_utils.ComputeTCPChecksum(&tcpHdr, hdr.Src, hdr.Dst, tcpPayload)
+
+ var tcpChecksumState string
+ if tcpComputedChecksum == tcpChecksumFromHeader {
+ tcpChecksumState = "OK"
+ } else {
+ tcpChecksumState = "FAIL"
+ }
+
+ if tcpChecksumState == "FAIL" {
+ // drop the packet
+ fmt.Println("checksum failed, dropping packet")
+ return nil
+ }
+
+ switch tcpHdr.Flags {
+ case 0x02:
+ fmt.Println("I see a SYN flag")
+ // if the SYN flag is set, then send a SYNACK
+ available := false
+ mapMutex.Lock()
+ for _, socketEntry := range VHostSocketMaps {
+ if socketEntry.LocalPort == tcpHdr.DstPort && socketEntry.LocalIP == hdr.Dst.String() && socketEntry.State == Listening{
+ // add a new socketEntry to the map
+ newEntry := &SocketEntry{
+ LocalPort: tcpHdr.DstPort,
+ RemotePort: tcpHdr.SrcPort,
+ LocalIP: hdr.Dst.String(),
+ RemoteIP: hdr.Src.String(),
+ State: SYNRECIEVED,
+ Socket: socketsMade,
+ }
+ socketsMade += 1
+ // add the entry to the map
+ VHostSocketMaps[socketsMade] = newEntry
+ available = true
+ break
+ }
+ }
+ mapMutex.Unlock()
+ if !available {
+ fmt.Println("no socket available")
+ return nil
+ }
+ // make the header
+ tcpHdr := &header.TCPFields{
+ SrcPort: tcpHdr.DstPort,
+ DstPort: tcpHdr.SrcPort,
+ SeqNum: tcpHdr.SeqNum,
+ AckNum: tcpHdr.SeqNum + 1,
+ DataOffset: 20,
+ Flags: 0x12,
+ WindowSize: MAX_WINDOW_SIZE,
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+ // make the payload
+ synAckPayload := []byte{}
+ err := SendTCP(tcpHdr, synAckPayload, hdr.Dst, hdr.Src)
+ if err != nil {
+ fmt.Println(err)
+ }
+ break
+ case 0x12:
+ fmt.Println("I see a SYNACK flag")
+ // lookup for socket entry and update its state
+ mapMutex.Lock()
+ for _, socketEntry := range VHostSocketMaps {
+ if socketEntry.LocalPort == tcpHdr.DstPort && socketEntry.LocalIP == hdr.Dst.String() && socketEntry.State == SYNSENT {
+ socketEntry.State = Established
+ break
+ }
+ }
+ mapMutex.Unlock()
+
+ // send an ACK
+ // make the header
+ tcpHdr := &header.TCPFields{
+ SrcPort: tcpHdr.DstPort,
+ DstPort: tcpHdr.SrcPort,
+ SeqNum: tcpHdr.SeqNum,
+ AckNum: tcpHdr.SeqNum + 1,
+ DataOffset: 20,
+ Flags: 0x10,
+ WindowSize: MAX_WINDOW_SIZE,
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+ // make the payload
+ ackPayload := []byte{}
+ err := SendTCP(tcpHdr, ackPayload, hdr.Dst, hdr.Src)
+ if err != nil {
+ fmt.Println(err)
+ }
+ break
+ case 0x10:
+ fmt.Println("I see an ACK flag")
+ // lookup for socket entry and update its state
+ // set synChan to true (TODO)
+
+ mapMutex.Lock()
+ for _, socketEntry := range VHostSocketMaps {
+ if socketEntry.LocalPort == tcpHdr.DstPort && socketEntry.LocalIP == hdr.Dst.String() && socketEntry.State == SYNRECIEVED {
+ socketEntry.State = Established
+ break
+ }
+ }
+ mapMutex.Unlock()
+ break
+ default:
+ fmt.Println("I see a non TCP packet")
+ break
+ }
+
+
+ return nil
+}
+
// *********************************************** HELPERS **********************************************************
// Route returns the next HOP, based on longest prefix match for a given ip
@@ -1006,3 +1141,210 @@ func CleanUp() {
time.Sleep(5 * time.Millisecond)
}
+
+// ************************************** TCP FUNCTIONS **********************************************************
+
+type ConnectionState string
+const (
+ Established ConnectionState = "ESTABLISHED"
+ Listening ConnectionState = "LISTENING"
+ Closed ConnectionState = "CLOSED"
+ SYNSENT ConnectionState = "SYNSENT"
+ SYNRECIEVED ConnectionState = "SYNRECIEVED"
+ MAX_WINDOW_SIZE = 65535
+)
+
+// VTCPListener represents a listener socket (similar to Go’s net.TCPListener)
+type VTCPListener struct {
+ LocalAddr string
+ LocalPort uint16
+ RemoteAddr string
+ RemotePort uint16
+ Socket int
+ State ConnectionState
+}
+
+// // VTCPConn represents a “normal” socket for a TCP connection between two endpoints (similar to Go’s net.TCPConn)
+type VTCPConn struct {
+ LocalAddr string
+ LocalPort uint16
+ RemoteAddr string
+ RemotePort uint16
+ Socket int
+ State ConnectionState
+ Buffer []byte
+}
+
+type SocketEntry struct {
+ Socket int
+ LocalIP string
+ LocalPort uint16
+ RemoteIP string
+ RemotePort uint16
+ State ConnectionState
+}
+
+// create a socket map to print the local and remote ip and port as well as the state of the socket
+var VHostSocketMaps = make(map[int]*SocketEntry)
+var mapMutex = &sync.Mutex{}
+var socketsMade = 0
+var startingSeqNum = rand.Uint32()
+
+// Listen Sockets
+func VListen(port uint16) (*VTCPListener, error) {
+ myIP := GetInterfaces()[0].IpPrefix.Addr()
+ listener := &VTCPListener{
+ Socket: socketsMade,
+ State: Listening,
+ LocalPort: port,
+ LocalAddr: myIP.String(),
+ }
+
+ // add the socket to the socket map
+ mapMutex.Lock()
+ VHostSocketMaps[socketsMade] = &SocketEntry{
+ Socket: socketsMade,
+ LocalIP: myIP.String(),
+ LocalPort: port,
+ RemoteIP: "0.0.0.0",
+ RemotePort: 0,
+ State: Listening,
+ }
+ mapMutex.Unlock()
+ socketsMade += 1
+ return listener, nil
+
+}
+
+func (l *VTCPListener) VAccept() (*VTCPConn, error) {
+ // synChan = make(chan bool)
+ for {
+ // wait for a SYN request
+ mapMutex.Lock()
+ for _, socketEntry := range VHostSocketMaps {
+ if socketEntry.State == Established {
+ // create a new VTCPConn
+ conn := &VTCPConn{
+ LocalAddr: socketEntry.LocalIP,
+ LocalPort: socketEntry.LocalPort,
+ RemoteAddr: socketEntry.RemoteIP,
+ RemotePort: socketEntry.RemotePort,
+ Socket: socketEntry.Socket,
+ State: Established,
+ }
+ return conn, nil
+ }
+ }
+ mapMutex.Unlock()
+
+ }
+}
+
+func GetRandomPort() uint16 {
+ const (
+ minDynamicPort = 49152
+ maxDynamicPort = 65535
+ )
+ return uint16(rand.Intn(maxDynamicPort - minDynamicPort) + minDynamicPort)
+}
+
+func VConnect(ip string, port uint16) (*VTCPConn, error) {
+ // get my ip address
+ myIP := GetInterfaces()[0].IpPrefix.Addr()
+ // get random port
+ portRand := GetRandomPort()
+
+ tcpHdr := &header.TCPFields{
+ SrcPort: portRand,
+ DstPort: port,
+ SeqNum: startingSeqNum,
+ AckNum: startingSeqNum,
+ DataOffset: 20,
+ Flags: header.TCPFlagSyn,
+ WindowSize: MAX_WINDOW_SIZE,
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+ payload := []byte{}
+ ipParsed, err := netip.ParseAddr(ip)
+ if err != nil {
+ return nil, err
+ }
+ err = SendTCP(tcpHdr, payload, myIP, ipParsed)
+ if err != nil {
+ return nil, err
+ }
+
+ conn := &VTCPConn{
+ LocalAddr: myIP.String(),
+ LocalPort: portRand,
+ RemoteAddr: ip,
+ RemotePort: port,
+ Socket: socketsMade,
+ State: Established,
+ Buffer: []byte{},
+ }
+
+ // add the socket to the socket map
+ mapMutex.Lock()
+ VHostSocketMaps[socketsMade] = &SocketEntry{
+ Socket: socketsMade,
+ LocalIP: myIP.String(),
+ LocalPort: portRand,
+ RemoteIP: ip,
+ RemotePort: port,
+ State: SYNSENT,
+ }
+ mapMutex.Unlock()
+ socketsMade += 1
+
+ return conn, nil
+}
+
+func SendTCP(tcpHdr *header.TCPFields, payload []byte, myIP netip.Addr, ipParsed netip.Addr) error {
+ checksum := iptcp_utils.ComputeTCPChecksum(tcpHdr, myIP, ipParsed, payload)
+ tcpHdr.Checksum = checksum
+
+ tcpHeaderBytes := make(header.TCP, iptcp_utils.TcpHeaderLen)
+ tcpHeaderBytes.Encode(tcpHdr)
+
+ ipPacketPayload := make([]byte, 0, len(tcpHeaderBytes)+len(payload))
+ ipPacketPayload = append(ipPacketPayload, tcpHeaderBytes...)
+ ipPacketPayload = append(ipPacketPayload, []byte(payload)...)
+
+ // lookup neighbor
+ address := ipParsed
+ hop, err := Route(address)
+ if err != nil {
+ fmt.Println(err)
+ return err
+ }
+ myAddr := hop.Interface.IpPrefix.Addr()
+
+ for _, neighbor := range GetNeighbors()[hop.Interface.Name] {
+ if neighbor.VipAddr == address ||
+ neighbor.VipAddr == hop.VIP && hop.Type == "S" {
+ bytesWritten, err := SendIP(&myAddr, neighbor, TCP_PROTOCOL, ipPacketPayload, ipParsed.String(), nil)
+ fmt.Printf("Sent %d bytes to %s\n", bytesWritten, neighbor.VipAddr.String())
+ if err != nil {
+ fmt.Println(err)
+ }
+ }
+ }
+ return nil
+}
+
+func SprintSockets() string {
+ tmp := ""
+ for _, socket := range VHostSocketMaps {
+ // remove the spaces of the local and remote ip variables
+ socket.LocalIP = strings.ReplaceAll(socket.LocalIP, " ", "")
+ socket.RemoteIP = strings.ReplaceAll(socket.RemoteIP, " ", "")
+ if socket.RemotePort == 0 {
+ tmp += fmt.Sprintf("%d\t%s\t%d\t%s\t\t%d\t%s\n", socket.Socket, socket.LocalIP, socket.LocalPort, socket.RemoteIP, socket.RemotePort, socket.State)
+ continue
+ }
+ tmp += fmt.Sprintf("%d\t%s\t%d\t%s\t%d\t%s\n", socket.Socket, socket.LocalIP, socket.LocalPort, socket.RemoteIP, socket.RemotePort, socket.State)
+ }
+ return tmp
+} \ No newline at end of file
diff --git a/pkg/iptcp_utils/iptcp_utils.go b/pkg/iptcp_utils/iptcp_utils.go
new file mode 100644
index 0000000..f8caf4d
--- /dev/null
+++ b/pkg/iptcp_utils/iptcp_utils.go
@@ -0,0 +1,144 @@
+package iptcp_utils
+
+import (
+ "encoding/binary"
+ "fmt"
+ "net/netip"
+ "strings"
+
+ "github.com/google/netstack/tcpip/header"
+)
+
+const (
+ TcpHeaderLen = header.TCPMinimumSize
+ TcpPseudoHeaderLen = 12
+ IpProtoTcp = header.TCPProtocolNumber
+ MaxVirtualPacketSize = 1400
+)
+
+
+
+// Build a TCPFields struct from the TCP byte array
+//
+// NOTE: the netstack package might have other options for parsing the header
+// that you may like better--this example is most similar to our other class
+// examples. Your mileage may vary!
+func ParseTCPHeader(b []byte) header.TCPFields {
+ td := header.TCP(b)
+ return header.TCPFields{
+ SrcPort: td.SourcePort(),
+ DstPort: td.DestinationPort(),
+ SeqNum: td.SequenceNumber(),
+ AckNum: td.AckNumber(),
+ DataOffset: td.DataOffset(),
+ Flags: td.Flags(),
+ WindowSize: td.WindowSize(),
+ Checksum: td.Checksum(),
+ }
+}
+
+// The TCP checksum is computed based on a "pesudo-header" that
+// combines the (virtual) IP source and destination address, protocol value,
+// as well as the TCP header and payload
+//
+// This is one example one is way to combine all of this information
+// and compute the checksum leveraging the netstack package.
+//
+// For more details, see the "Checksum" component of RFC9293 Section 3.1,
+// https://www.rfc-editor.org/rfc/rfc9293.txt
+func ComputeTCPChecksum(tcpHdr *header.TCPFields,
+ sourceIP netip.Addr, destIP netip.Addr, payload []byte) uint16 {
+
+ // Fill in the pseudo header
+ pseudoHeaderBytes := make([]byte, TcpPseudoHeaderLen)
+
+ // First are the source and dest IPs. This function only supports
+ // IPv4, so make sure the IPs are IPv4 addresses
+ copy(pseudoHeaderBytes[0:4], sourceIP.AsSlice())
+ copy(pseudoHeaderBytes[4:8], destIP.AsSlice())
+
+ // Next, add the protocol number and header length
+ pseudoHeaderBytes[8] = uint8(0)
+ pseudoHeaderBytes[9] = uint8(IpProtoTcp)
+
+ totalLength := TcpHeaderLen + len(payload)
+ binary.BigEndian.PutUint16(pseudoHeaderBytes[10:12], uint16(totalLength))
+
+ // Turn the TcpFields struct into a byte array
+ headerBytes := header.TCP(make([]byte, TcpHeaderLen))
+ headerBytes.Encode(tcpHdr)
+
+ // Compute the checksum for each individual part and combine To combine the
+ // checksums, we leverage the "initial value" argument of the netstack's
+ // checksum package to carry over the value from the previous part
+ pseudoHeaderChecksum := header.Checksum(pseudoHeaderBytes, 0)
+ headerChecksum := header.Checksum(headerBytes, pseudoHeaderChecksum)
+ fullChecksum := header.Checksum(payload, headerChecksum)
+
+ // Return the inverse of the computed value,
+ // which seems to be the convention of the checksum algorithm
+ // in the netstack package's implementation
+ return fullChecksum ^ 0xffff
+}
+
+// Compute the checksum using the netstack package
+func ComputeIPChecksum(b []byte) uint16 {
+ checksum := header.Checksum(b, 0)
+
+ // Invert the checksum value. Why is this necessary?
+ // This function returns the inverse of the checksum
+ // on an initial computation. While this may seem weird,
+ // it makes it easier to use this same function
+ // to validate the checksum on the receiving side.
+ // See ValidateChecksum in the receiver file for details.
+ checksumInv := checksum ^ 0xffff
+
+ return checksumInv
+}
+
+// Validate the checksum using the netstack package Here, we provide both the
+// byte array for the header AND the initial checksum value that was stored in
+// the header
+//
+// "Why don't we need to set the checksum value to 0 first?"
+//
+// Normally, the checksum is computed with the checksum field of the header set
+// to 0. This library creatively avoids this step by instead subtracting the
+// initial value from the computed checksum. If you use a different language or
+// checksum function, you may need to handle this differently.
+func ValidateIPChecksum(b []byte, fromHeader uint16) uint16 {
+ checksum := header.Checksum(b, fromHeader)
+
+ return checksum
+}
+
+// Pretty-print TCP flags value as a string
+func TCPFlagsAsString(flags uint8) string {
+ strMap := map[uint8]string{
+ header.TCPFlagAck: "ACK",
+ header.TCPFlagFin: "FIN",
+ header.TCPFlagPsh: "PSH",
+ header.TCPFlagRst: "RST",
+ header.TCPFlagSyn: "SYN",
+ header.TCPFlagUrg: "URG",
+ }
+
+ matches := make([]string, 0)
+
+ for b, str := range strMap {
+ if (b & flags) == b {
+ matches = append(matches, str)
+ }
+ }
+
+ ret := strings.Join(matches, "+")
+
+ return ret
+}
+
+// Pretty-print a TCP header (with pretty-printed flags)
+// Otherwise, using %+v in format strings is a good enough view in most cases
+func TCPFieldsToString(hdr *header.TCPFields) string {
+ return fmt.Sprintf("{SrcPort:%d DstPort:%d, SeqNum:%d AckNum:%d DataOffset:%d Flags:%s WindowSize:%d Checksum:%x UrgentPointer:%d}",
+ hdr.SrcPort, hdr.DstPort, hdr.SeqNum, hdr.AckNum, hdr.DataOffset, TCPFlagsAsString(hdr.Flags), hdr.WindowSize, hdr.Checksum, hdr.UrgentPointer)
+} \ No newline at end of file
diff --git a/pkg/tcpstack/tcpstack.go b/pkg/tcpstack/tcpstack.go
new file mode 100644
index 0000000..ec451a7
--- /dev/null
+++ b/pkg/tcpstack/tcpstack.go
@@ -0,0 +1,232 @@
+package tcpstack
+
+// import (
+// // "encoding/binary"
+// "fmt"
+// // "syscall"
+// "iptcp/pkg/ipstack"
+// // ipv4header "github.com/brown-csci1680/iptcp-headers"
+// // tcpheader "github.com/brown-csci1680/iptcp-headers"
+// "github.com/google/netstack/tcpip/header"
+// "github.com/pkg/errors"
+// "iptcp/pkg/iptcp_utils"
+// // "log"
+// // "net"
+// "net/netip"
+// // "sync"
+// // "time"
+// "math/rand"
+// "sync"
+// "strings"
+// )
+
+// type ConnectionState string
+// const (
+// Established ConnectionState = "ESTABLISHED"
+// Listening ConnectionState = "LISTENING"
+// Closed ConnectionState = "CLOSED"
+// SYNSENT ConnectionState = "SYNSENT"
+// MAX_WINDOW_SIZE = 65535
+// )
+
+// // VTCPListener represents a listener socket (similar to Go’s net.TCPListener)
+// type VTCPListener struct {
+// LocalAddr string
+// LocalPort uint16
+// RemoteAddr string
+// RemotePort uint16
+// Socket int
+// State ConnectionState
+// }
+
+// // // VTCPConn represents a “normal” socket for a TCP connection between two endpoints (similar to Go’s net.TCPConn)
+// type VTCPConn struct {
+// LocalAddr string
+// LocalPort uint16
+// RemoteAddr string
+// RemotePort uint16
+// Socket int
+// State ConnectionState
+// Buffer []byte
+// }
+
+// type SocketEntry struct {
+// Socket int
+// LocalIP string
+// LocalPort uint16
+// RemoteIP string
+// RemotePort uint16
+// State ConnectionState
+// }
+
+// // create a socket map to print the local and remote ip and port as well as the state of the socket
+// var VHostSocketMaps = make(map[string]map[string]*SocketEntry)
+// var mapMutex = &sync.Mutex{}
+// var socketsMade = 0
+// var myIP = ipstack.GetInterfaces()[0].IpPrefix.Addr()
+
+// // Listen Sockets
+// func VListen(port uint16) (*VTCPListener, error) {
+// // get my ip address
+// // myIP := ipstack.GetInterfaces()[0].IpPrefix.Addr()
+// listener := &VTCPListener{
+// Socket: socketsMade,
+// State: Listening,
+// LocalPort: port,
+// LocalAddr: myIP.String(),
+// }
+
+// vhostMap, ok := VHostSocketMaps[myIP.String()]
+// if !ok {
+// vhostMap = make(map[string]*SocketEntry)
+// VHostSocketMaps[fmt.Sprintf("%s:%d", "", port)] = vhostMap
+// }
+// // add the socket to the socket map
+// mapMutex.Lock()
+// vhostMap[fmt.Sprintf("%s:%d", "", port)] = &SocketEntry{
+// Socket: socketsMade,
+// LocalIP: "0.0.0.0",
+// LocalPort: port,
+// RemoteIP: "0.0.0.0",
+// RemotePort: 0,
+// State: Listening,
+// }
+// mapMutex.Unlock()
+// socketsMade += 1
+// return listener, nil
+
+// }
+
+// func (l *VTCPListener) VAccept() (*VTCPConn, error) {
+
+// for {
+// // wait for a SYN request
+// // if there is a SYN request, create a new socket, send a SYN-ACK response, and return the socket
+// myInterface := ipstack.GetInterfaces()[0]
+// err := ipstack.RecvIP(myInterface, nil)
+// if err != nil {
+// return nil, err
+// } else {
+// break
+// }
+// }
+// socketsMade += 1
+// conn := &VTCPConn{
+// }
+
+// // add the socket to the socket map
+// // mapMutex.Lock()
+
+// return conn, nil
+// // create a new socket
+// // return the socket
+// }
+
+// func GetRandomPort() uint16 {
+// const (
+// minDynamicPort = 49152
+// maxDynamicPort = 65535
+// )
+// return uint16(rand.Intn(maxDynamicPort - minDynamicPort) + minDynamicPort)
+// }
+
+// func VConnect(ip string, port uint16) (*VTCPConn, error) {
+// // get my ip address
+// // myIP := ipstack.GetInterfaces()[0].IpPrefix.Addr()
+// // get random port
+// portRand := GetRandomPort()
+// // Create a new VTCPConn.
+// conn := &VTCPConn{
+// LocalAddr: myIP.String(),
+// LocalPort: portRand,
+// RemoteAddr: ip,
+// RemotePort: port,
+// Socket: socketsMade,
+// State: SYNSENT,
+// Buffer: []byte{},
+// }
+// // get the socket entry from the socket map in the given port
+// vhostMap, ok := VHostSocketMaps[ip]
+// if !ok {
+// return nil, errors.New("socket not found")
+// }
+
+// socketEntry, ok := vhostMap[fmt.Sprintf("%s:%d", ip, port)]
+// if !ok {
+// return nil, errors.New("socket not found")
+// }
+
+// // if the socket is closed, return an error
+// if socketEntry.State == Closed {
+// return nil, errors.New("socket is closed")
+// }
+
+// // lookup neighbor
+// neighbors := ipstack.GetNeighbors()[ipstack.GetInterfaces()[0].Name]
+// var neighbor *ipstack.Neighbor
+// for _, n := range neighbors {
+// if n.VipAddr.String() == ip {
+// neighbor = n
+// }
+// }
+
+// // Send SYN request.
+// if socketEntry.State == Listening {
+// tcpHdr := &header.TCPFields{
+// SrcPort: portRand,
+// DstPort: port,
+// SeqNum: 1,
+// AckNum: 0,
+// DataOffset: 20,
+// Flags: header.TCPFlagSyn,
+// WindowSize: MAX_WINDOW_SIZE,
+// Checksum: 0,
+// UrgentPointer: 0,
+// }
+// payload := []byte{}
+// ipParsed, err := netip.ParseAddr(ip)
+// if err != nil {
+// return nil, err
+// }
+// checksum := iptcp_utils.ComputeTCPChecksum(tcpHdr, myIP, ipParsed, []byte{})
+// tcpHdr.Checksum = checksum
+
+// tcpHeaderBytes := make(header.TCP, iptcp_utils.TcpHeaderLen)
+// tcpHeaderBytes.Encode(tcpHdr)
+
+// ipPacketPayload := make([]byte, 0, len(tcpHeaderBytes)+len(payload))
+// ipPacketPayload = append(ipPacketPayload, tcpHeaderBytes...)
+// ipPacketPayload = append(ipPacketPayload, []byte(payload)...)
+
+
+// _, err = ipstack.SendIP(&myIP, neighbor, 6, ipPacketPayload, ip, nil)
+// if err != nil {
+// return nil, err
+// }
+// }
+
+// // wait for a SYN-ACK response in the socket buffer
+// myIPinterface := ipstack.GetInterfaces()[0]
+// for {
+// err := ipstack.RecvIP(myIPinterface, nil)
+// if err != nil {
+// return nil, err
+// } else {
+// break
+// }
+// }
+
+// return conn, nil
+// }
+
+// func SprintSockets() string {
+// var socketStrings []string
+// // for _, socket := range SocketMap {
+// // socketStrings = append(socketStrings, fmt.Sprintf("%d\t%s\t%d\t%s\t%d\t%s", socket.Socket, socket.LocalIP, socket.LocalPort, socket.RemoteIP, socket.RemotePort, socket.State))
+// // }
+// // only print the sockets in the map with myIP as the key
+// for _, socket := range VHostSocketMaps[myIP.String()] {
+// socketStrings = append(socketStrings, fmt.Sprintf("%d\t%s\t%d\t%s\t%d\t%s", socket.Socket, socket.LocalIP, socket.LocalPort, socket.RemoteIP, socket.RemotePort, socket.State))
+// }
+// return strings.Join(socketStrings, "\t")
+// } \ No newline at end of file
diff --git a/vhost b/vhost
index e967990..350313e 100755
--- a/vhost
+++ b/vhost
Binary files differ
diff --git a/vrouter b/vrouter
index 88d094f..8597bff 100755
--- a/vrouter
+++ b/vrouter
Binary files differ