diff options
Diffstat (limited to 'pkg/ipstack/ipstack.go')
-rw-r--r-- | pkg/ipstack/ipstack.go | 298 |
1 files changed, 298 insertions, 0 deletions
diff --git a/pkg/ipstack/ipstack.go b/pkg/ipstack/ipstack.go new file mode 100644 index 0000000..770fd17 --- /dev/null +++ b/pkg/ipstack/ipstack.go @@ -0,0 +1,298 @@ +package ipstack + +import ( + "fmt" + ipv4header "github.com/brown-csci1680/iptcp-headers" + "github.com/google/netstack/tcpip/header" + "github.com/pkg/errors" + "../../pkg/lnxconfig" + "log" + "net" + "net/netip" + "os" +) + +const ( + MAX_IP_PACKET_SIZE = 1400 +) + +type Interface struct { + Name string + AssignedIP netip.Addr + AssignedPrefix netip.Prefix + + UDPAddr netip.AddrPort + State bool + neighbors map[netip.AddrPort]netip.AddrPort +} + +// type Host struct { +// Interface Interface +// Neighbors []Neighbor +// } + +// type Router struct { +// Interfaces []Interface +// Neighbors []Neighbor +// RIPNeighbors []Neighbor +// } + + +type Neighbor struct{ + DestAddr netip.Addr + UDPAddr netip.AddrPort + + InterfaceName string +} + +type RIPMessage struct { + command uint8_t; + numEntries uint8_t; + entries []RIPEntry; +} + +type RIPEntry struct { + addr netip.Addr; + cost uint16_t; + mask netip.Prefix; +} + +myInterfaces := make([]Interface); +myNeighbors := make(map[string]Neighbor) +myRIPNeighbors := make(map[string]Neighbor) +protocolHandlers := make(map[uint16]HandlerFunc) +// routingTable := make(map[Address]Routing) + +func Initialize(config IpConfig) (error) { + if len(os.Args) != 2 { + fmt.Printf("Usage: %s <configFile>\n", os.Args[0]) + os.Exit(1) + } + fileName := os.Args[1] + + // Parse the file + lnxConfig, err := lnxconfig.ParseConfig(fileName) + if err != nil { + panic(err) + } + + // populate routing table??? + for _, iface := range lnxConfig.Interfaces { + myInterfaces = append(myInterfaces, Interface{iface.Name, iface.AssignedIP, iface.AssignedPrefix, iface.UDPAddr, 0}) + } + + for _, neighbor := range lnxConfig.Neighbors { + myNeighbors[neighbor.DestAddr.String()] = Neighbor{neighbor.DestAddr, neighbor.UDPAddr, neighbor.InterfaceName} + } + + // add RIP neighbors + for _, neighbor := range lnxConfig.RipNeighbors { + myRIPNeighbors[neighbor.DestAddr.String()] = Neighbor{neighbor.DestAddr, neighbor.UDPAddr, neighbor.InterfaceName} + } +} + +func ListerToInterfaces() { + for _, iface := range myInterfaces { + go RecvIp(iface) + } +} + +func RecvIp(iface Interface) (error) { + for { + buffer := make([]byte, MAX_IP_PACKET_SIZE) + _, sourceAddr, err := iface.udp.ReadFrom(buffer) + if err != nil { + log.Panicln("Error reading from UDP socket ", err) + } + + hdr, err := ipv4header.ParseHeader(buffer) + + if err != nil { + fmt.Println("Error parsing header", err) + continue + } + + headerSize := hdr.Len + headerBytes := buffer[:headerSize] + + checksumFromHeader := uint16(hdr.Checksum) + computedChecksum := ValidateChecksum(headerBytes, checksumFromHeader) + + var checksumState string + if computedChecksum == checksumFromHeader { + checksumState = "OK" + } else { + checksumState = "FAIL" + continue + } + + // check ttl + ttl := data[8] + if ttl == 0 { + fmt.Println("TTL is 0") + continue + } + + + destAddr := netip.AddrFrom(data[16:20]) + protocolNum := data[9] + + if destAddr == iface.addr { + // send to handler + protocolHandlers[protocolNum](data) + // message := buffer[headerSize:] + + // fmt.Printf("Received IP packet from %s\nHeader: %v\nChecksum: %s\nMessage: %s\n", + // sourceAddr.String(), hdr, checksumState, string(message)) + } else { + // decrement ttl and update checksum + data[8] = ttl - 1 + data[10] = 0 + data[11] = 0 + newChecksum := int(ComputeChecksum(data[:headerSize])) + data[10] = newChecksum >> 8 + data[11] = newChecksum & 0xff + + // check neighbors + for _, neighbor := range iface.neighbors { + if neighbor == destAddr { + // send to neighbor + // SendIp(destAddr, protocolNum, data) + } + } + + // check forwarding table + + } + + } +} + +func ValidateChecksum(b []byte, fromHeader uint16) uint16 { + checksum := header.Checksum(b, fromHeader) + + return checksum +} + +func SendIp(dst netip.Addr, port uint16, protocolNum uint16, data []byte, iface Interface) (error) { + bindLocalAddr, err := net.ResolveUDPAddr("udp4", iface.UDPAddr.String()) + if err != nil { + log.Panicln("Error resolving address: ", err) + } + + addrString := fmt.Sprintf("%s:%s", dst, port) + remoteAddr, err := net.ResolveUDPAddr("udp4", addrString) + if err != nil { + log.Panicln("Error resolving address: ", err) + } + + fmt.Printf("Sending to %s:%d\n", + remoteAddr.IP.String(), remoteAddr.Port) + + // Bind on the local UDP port: this sets the source port + // and creates a conn + conn, err := net.ListenUDP("udp4", bindLocalAddr) + if err != nil { + log.Panicln("Dial: ", err) + } + + // Start filling in the header + message := data[20:] + hdr := ipv4header.IPv4Header{ + Version: data[0] >> 4, + Len: 20, // Header length is always 20 when no IP options + TOS: data[1], + TotalLen: ipv4header.HeaderLen + len(message), + ID: data[4], + Flags: data[6] >> 5, + FragOff: data[6] & 0x1f, + TTL: data[8], + Protocol: data[9], + Checksum: 0, // Should be 0 until checksum is computed + Src: netip.MustParseAddr(iface.addr.String()), + Dst: netip.MustParseAddr(dst.String()), + Options: []byte{}, + } + + // Assemble the header into a byte array + headerBytes, err := hdr.Marshal() + if err != nil { + log.Fatalln("Error marshalling header: ", err) + } + + // Compute the checksum (see below) + // Cast back to an int, which is what the Header structure expects + hdr.Checksum = int(ComputeChecksum(headerBytes)) + + headerBytes, err = hdr.Marshal() + if err != nil { + log.Fatalln("Error marshalling header: ", err) + } + + bytesToSend := make([]byte, 0, len(headerBytes)+len(message)) + bytesToSend = append(bytesToSend, headerBytes...) + bytesToSend = append(bytesToSend, []byte(message)...) + + // Send the message to the "link-layer" addr:port on UDP + bytesWritten, err := conn.WriteToUDP(bytesToSend, remoteAddr) + if err != nil { + log.Panicln("Error writing to socket: ", err) + } + fmt.Printf("Sent %d bytes\n", bytesWritten) +} + +func ComputeChecksum(b []byte) uint16 { + checksum := header.Checksum(b, 0) + checksumInv := checksum ^ 0xffff + + return checksumInv +} + +func ForwardIP(data []byte) (error) { +} + +type HandlerFunc = func help(*Packet, []interface{}) (error) { + + // do smth with packet +} + +func AddRecvHandler(protocolNum uint8, callbackFunc HandlerFunc) (error) { + if protocolHandlers[protocolNum] != nil { + fmt.Printf("Warning: Handler for protocol %d already exists", protocolNum) + } + protocolHandlers[protocolNum] = callbackFunc + return nil +} + +func RemoveRecvHandler(protocolNum uint8) (error) { + // consider error + if protocolHandlers[protocolNum] == nil { + return errors.Errorf("No handler for protocol %d", protocolNum) + } + delete(protocolHandlers, protocolNum) + return nil +} + +// func routeRip(data []byte) (error) { +// // deconstruct packet +// newRIPMessage := RIPMessage{} +// newRIPMessage.command = data[0] +// newRIPMessage.numEntries = data[1] +// newRIPMessage.entries = make([]RIPEntry, newRIPMessage.numEntries) +// } + +func PrintNeighbors() { + for _, iface := range myNeighbors { + fmt.Printf("%s\n", iface.addr.String()) + } +} + +func PrintInterfaces() { + for _, iface := range myInterfaces { + fmt.Printf("%s\n", iface.addr.String()) + } +} +func GetNeighbors() ([]netip.Addr) { + return myNeighbors +} + |