diff options
Diffstat (limited to 'pkg/ipstack/ipstack.go')
-rw-r--r-- | pkg/ipstack/ipstack.go | 413 |
1 files changed, 413 insertions, 0 deletions
diff --git a/pkg/ipstack/ipstack.go b/pkg/ipstack/ipstack.go new file mode 100644 index 0000000..be8bc1e --- /dev/null +++ b/pkg/ipstack/ipstack.go @@ -0,0 +1,413 @@ +package ipstack + +import ( + "fmt" + ipv4header "github.com/brown-csci1680/iptcp-headers" + "github.com/google/netstack/tcpip/header" + "github.com/pkg/errors" + "iptcp/pkg/lnxconfig" + "log" + "net" + "net/netip" + "time" +) + +const ( + MAX_IP_PACKET_SIZE = 1400 + LOCAL_COST uint32 = 0 + STATIC_COST uint32 = 4294967295 // 2^32 - 1 +) + +// STRUCTS --------------------------------------------------------------------- +type Interface struct { + Name string + IpPrefix netip.Prefix + UdpAddr netip.AddrPort + + RecvSocket net.UDPConn + SocketChannel chan bool + State bool +} + +type Neighbor struct { + VipAddr netip.Addr + UdpAddr netip.AddrPort + + SendSocket net.UDPConn + SocketChannel chan bool +} + +type RIPMessage struct { + command uint8 + numEntries uint8 + entries []RIPEntry +} + +type RIPEntry struct { + addr netip.Addr + cost uint32 + mask netip.Prefix +} + +type Hop struct { + Cost uint32 + VipAsStr string +} + +// GLOBAL VARIABLES (data structures) ------------------------------------------ +var myVIP netip.Addr +var myInterfaces []*Interface +var myNeighbors = make(map[string][]*Neighbor) + +// var myRIPNeighbors = make(map[string]Neighbor) +type HandlerFunc func(int, string, *[]byte) error + +var protocolHandlers = make(map[uint16]HandlerFunc) + +// var routingTable = routingtable.New() +var routingTable = make(map[netip.Prefix]Hop) + +// reference: https://github.com/brown-csci1680/lecture-examples/blob/main/ip-demo/cmd/udp-ip-recv/main.go +func createUDPConn(UdpAddr netip.AddrPort, conn *net.UDPConn) error { + listenString := UdpAddr.String() + listenAddr, err := net.ResolveUDPAddr("udp4", listenString) + if err != nil { + return errors.WithMessage(err, "Error resolving address->\t"+listenString) + } + tmpConn, err := net.ListenUDP("udp4", listenAddr) + if err != nil { + return errors.WithMessage(err, "Could not bind to UDP port->\t"+listenString) + } + *conn = *tmpConn + + return nil +} + +func Initialize(lnxFilePath string) error { + //if len(os.Args) != 2 { + // fmt.Printf("Usage: %s <configFile>\n", os.Args[0]) + // os.Exit(1) + //} + //lnxFilePath := os.Args[1] + + // Parse the file + lnxConfig, err := lnxconfig.ParseConfig(lnxFilePath) + if err != nil { + return errors.WithMessage(err, "Error parsing config file->\t"+lnxFilePath) + } + + // 1) initialize the interfaces on this node here and into the routing table + static := false + for _, iface := range lnxConfig.Interfaces { + prefix := netip.PrefixFrom(iface.AssignedIP, iface.AssignedPrefix.Bits()) + i := &Interface{ + Name: iface.Name, + IpPrefix: prefix, + UdpAddr: iface.UDPAddr, + RecvSocket: net.UDPConn{}, + SocketChannel: make(chan bool), + State: false, + } + + err := createUDPConn(iface.UDPAddr, &i.RecvSocket) + if err != nil { + return errors.WithMessage(err, "Error creating UDP socket for interface->\t"+iface.Name) + } + go InterfaceListenerRoutine(i.RecvSocket, i.SocketChannel) + myInterfaces = append(myInterfaces, i) + + // TODO: (FOR HOSTS ONLY) + // add STATIC to routing table + if !static { + ifacePrefix := netip.MustParsePrefix("0.0.0.0/0") + routingTable[ifacePrefix] = Hop{STATIC_COST, iface.Name} + static = true + } + } + + // 2) initialize the neighbors connected to the node and into the routing table + for _, neighbor := range lnxConfig.Neighbors { + n := &Neighbor{ + VipAddr: neighbor.DestAddr, + UdpAddr: neighbor.UDPAddr, + SendSocket: net.UDPConn{}, + SocketChannel: make(chan bool), + } + + err := createUDPConn(neighbor.UDPAddr, &n.SendSocket) + if err != nil { + return errors.WithMessage(err, "Error creating UDP socket for neighbor->\t"+neighbor.DestAddr.String()) + } + + myNeighbors[neighbor.InterfaceName] = append(myNeighbors[neighbor.InterfaceName], n) + + // add to routing table + // TODO: REVISIT AND SEE IF "24" IS CORRECT + neighborPrefix := netip.PrefixFrom(neighbor.DestAddr, 24) + routingTable[neighborPrefix] = Hop{LOCAL_COST, neighbor.InterfaceName} + } + + return nil +} + +func InterfaceListenerRoutine(socket net.UDPConn, signal <-chan bool) { + isUp := false + closed := false + + // go routine that hangs on the recv + fmt.Println("MAKING GO ROUTINE TO LISTEN:\t", socket.LocalAddr().String()) + go func() { + defer func() { // on close, set isUp to false + fmt.Println("exiting go routine that listens on ", socket.LocalAddr().String()) + }() + + for { + if closed { // stop this go routine if channel is closed + return + } + if !isUp { // don't call the listeners if interface is down + continue + } + // TODO: remove these "training wheels" + time.Sleep(1 * time.Millisecond) + err := RecvIP(socket, &isUp) + if err != nil { + fmt.Println("Error receiving IP packet", err) + return + } + } + }() + + for { + select { + case sig, ok := <-signal: + if !ok { + fmt.Println("channel closed, exiting") + closed = true + return + } + fmt.Println("received isUP SIGNAL with value", sig) + isUp = sig + default: + continue + } + } +} + +func InterfaceUp(iface *Interface) { + iface.State = true + iface.SocketChannel <- true +} + +func InterfaceDown(iface *Interface) { + iface.SocketChannel <- false + iface.State = false +} + +func GetInterfaceByName(ifaceName string) (*Interface, error) { + for _, iface := range myInterfaces { + if iface.Name == ifaceName { + return iface, nil + } + } + + return nil, errors.Errorf("No interface with name %s", ifaceName) +} + +func GetNeighborsToInterface(ifaceName string) ([]*Neighbor, error) { + if neighbors, ok := myNeighbors[ifaceName]; ok { + return neighbors, nil + } + + return nil, errors.Errorf("No interface with name %s", ifaceName) +} + +func SprintInterfaces() string { + buf := "" + for _, iface := range myInterfaces { + buf += fmt.Sprintf("%s\t%s\t%t\n", iface.Name, iface.IpPrefix.String(), iface.State) + } + return buf +} + +func SprintNeighbors() string { + buf := "" + for ifaceName, neighbor := range myNeighbors { + for _, n := range neighbor { + buf += fmt.Sprintf("%s\t%s\t%s\n", ifaceName, n.UdpAddr.String(), n.VipAddr.String()) + } + } + return buf +} + +func SprintRoutingTable() string { + buf := "" + for prefix, hop := range routingTable { + buf += fmt.Sprintf("%s\t%s\t%d\n", prefix.String(), hop.VipAsStr, hop.Cost) + } + return buf +} + +func DebugNeighbors() { + for ifaceName, neighbor := range myNeighbors { + for _, n := range neighbor { + fmt.Printf("%s\t%s\t%s\n", ifaceName, n.UdpAddr.String(), n.VipAddr.String()) + } + } +} + +func CleanUp() { + fmt.Print("Cleaning up...\n") + // go through the interfaces, pop thread & close the UDP FDs + for _, iface := range myInterfaces { + if iface.SocketChannel != nil { + close(iface.SocketChannel) + } + err := iface.RecvSocket.Close() + if err != nil { + continue + } + } + + // go through the neighbors, pop thread & close the UDP FDs + for _, neighbor := range myNeighbors { + for _, n := range neighbor { + if n.SocketChannel != nil { + close(n.SocketChannel) + } + err := n.SendSocket.Close() + if err != nil { + continue + } + } + } + + // delete all the neighbors + myNeighbors = make(map[string][]*Neighbor) + // delete all the interfaces + myInterfaces = nil + // delete the routing table + routingTable = make(map[netip.Prefix]Hop) + + time.Sleep(5 * time.Millisecond) +} + +// TODO: have it take TTL so we can decrement it when forwarding +func SendIP(src Interface, dest Neighbor, protocolNum int, message []byte) error { + hdr := ipv4header.IPv4Header{ + Version: 4, + Len: 20, // Header length is always 20 when no IP options + TOS: 0, + TotalLen: ipv4header.HeaderLen + len(message), + ID: 0, + Flags: 0, + FragOff: 0, + TTL: 32, + Protocol: protocolNum, + Checksum: 0, // Should be 0 until checksum is computed + Src: src.IpPrefix.Addr(), + Dst: dest.VipAddr, + Options: []byte{}, + } + + // Assemble the header into a byte array + headerBytes, err := hdr.Marshal() + if err != nil { + return err + } + + // Compute the checksum (see below) + // Cast back to an int, which is what the Header structure expects + hdr.Checksum = int(ComputeChecksum(headerBytes)) + + headerBytes, err = hdr.Marshal() + if err != nil { + log.Fatalln("Error marshalling header: ", err) + } + + bytesToSend := make([]byte, 0, len(headerBytes)+len(message)) + bytesToSend = append(bytesToSend, headerBytes...) + bytesToSend = append(bytesToSend, []byte(message)...) + + // Send the message to the "link-layer" addr:port on UDP + listenAddr, err := net.ResolveUDPAddr("udp4", dest.UdpAddr.String()) + if err != nil { + return err + } + bytesWritten, err := dest.SendSocket.WriteToUDP(bytesToSend, listenAddr) + if err != nil { + return err + } + fmt.Printf("Sent %d bytes to %s\n", bytesWritten, listenAddr.String()) + + return nil +} + +func RecvIP(conn net.UDPConn, isOpen *bool) error { + buffer := make([]byte, MAX_IP_PACKET_SIZE) // TODO: fix wordking + + // Read on the UDP port + fmt.Println("wating to read from UDP socket") + _, sourceAddr, err := conn.ReadFromUDP(buffer) + if err != nil { + return err + } + + if !*isOpen { + return errors.New("interface is down") + } + + // Marshal the received byte array into a UDP header + // NOTE: This does not validate the checksum or check any fields + // (You'll need to do this part yourself) + hdr, err := ipv4header.ParseHeader(buffer) + if err != nil { + // What should you if the message fails to parse? + // Your node should not crash or exit when you get a bad message. + // Instead, simply drop the packet and return to processing. + fmt.Println("Error parsing header", err) + return err + } + + headerSize := hdr.Len + headerBytes := buffer[:headerSize] + checksumFromHeader := uint16(hdr.Checksum) + computedChecksum := ValidateChecksum(headerBytes, checksumFromHeader) + + var checksumState string + if computedChecksum == checksumFromHeader { + checksumState = "OK" + } else { + checksumState = "FAIL" + } + + // Next, get the message, which starts after the header + message := buffer[headerSize:] + + // Finally, print everything out + fmt.Printf("Received IP packet from %s\nHeader: %v\nChecksum: %s\nMessage: %s\n", + sourceAddr.String(), hdr, checksumState, string(message)) + + // TODO: handle the message + // 1) check if the TTL & checksum is valid + // 2) check if the message is for me, if so, sendUP (aka call the correct handler) + // if not, need to forward the packer to a neighbor or check the table + // after decrementing TTL and updating checksum + // 3) check if message is for a neighbor, if so, sendIP there + // 4) check if message is for a neighbor, if so, forward to the neighbor with that VIP + + return nil +} + +func ComputeChecksum(b []byte) uint16 { + checksum := header.Checksum(b, 0) + checksumInv := checksum ^ 0xffff + + return checksumInv +} + +func ValidateChecksum(b []byte, fromHeader uint16) uint16 { + checksum := header.Checksum(b, fromHeader) + + return checksum +} |