|
|
|
@ -3,11 +3,15 @@ package main
|
|
|
|
|
import (
|
|
|
|
|
"bytes"
|
|
|
|
|
"encoding/binary"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
"io"
|
|
|
|
|
"io/fs"
|
|
|
|
|
"log"
|
|
|
|
|
"net"
|
|
|
|
|
"os"
|
|
|
|
|
"strings"
|
|
|
|
|
"syscall"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
const (
|
|
|
|
@ -57,63 +61,76 @@ func (tftp *tftpServer) listenAndServe() {
|
|
|
|
|
log.Printf("error while reading packet: '%v'\n", err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
err = tftp.handleConnection(addr, numRead, body)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("got some error: '%v'\n", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tftp.handleConnection(addr, numRead, body)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (tftp *tftpServer) handleConnection(addr net.Addr, numRead int, body []byte) error {
|
|
|
|
|
func (tftp *tftpServer) handleConnection(addr net.Addr, numRead int, body []byte) {
|
|
|
|
|
cli, ok := tftp.connections[addr.String()]
|
|
|
|
|
if !ok {
|
|
|
|
|
cli = newClient(addr)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
req, err := newRequest(numRead, body)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
err := func() error {
|
|
|
|
|
req, err := newRequest(numRead, body)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err = tftp.handleRequest(cli, req)
|
|
|
|
|
if err != nil {
|
|
|
|
|
err = tftp.handleRequest(cli, req)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
resp := newResponse(cli, req)
|
|
|
|
|
err = tftp.handleResponse(cli, resp)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, err = tftp.sendResponse(cli, resp)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
resp := newResponse(cli, req)
|
|
|
|
|
err = tftp.handleResponse(cli, resp)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
tftpErr, ok := err.(*tftpError)
|
|
|
|
|
if !ok {
|
|
|
|
|
log.Printf("Got unexpected error: %v\n", err)
|
|
|
|
|
tftpErr = newTFTPError(ecNDEF, "Unexpected error.")
|
|
|
|
|
}
|
|
|
|
|
_, err = tftp.sendError(cli, tftpErr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, err = tftp.sendResponse(cli, resp)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (tftp *tftpServer) handleRequest(cli *client, req *request) error {
|
|
|
|
|
// check for filename (rrq)
|
|
|
|
|
// check for filename (wrq)
|
|
|
|
|
// check for disk space
|
|
|
|
|
// check for file exist (wrq)
|
|
|
|
|
// access violation ?
|
|
|
|
|
// no such user ?
|
|
|
|
|
|
|
|
|
|
// checking for illegal operations
|
|
|
|
|
if req.opcode < opRRQ || req.opcode > opERROR {
|
|
|
|
|
return fmt.Errorf("Illegal operation!\n")
|
|
|
|
|
return newTFTPError(ecILL)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// checking for unknown client
|
|
|
|
|
_, clientExists := tftp.connections[cli.tid.String()]
|
|
|
|
|
if !clientExists && req.opcode != opRRQ && req.opcode != opWRQ {
|
|
|
|
|
return fmt.Errorf("Unknown client!\n")
|
|
|
|
|
return newTFTPError(ecUTID)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: handle this properly (probably will need to close 'connection')
|
|
|
|
|
if req.opcode == opERROR {
|
|
|
|
|
log.Printf("Got error from client: '%s' (%v)\n", req.errorMessage, req.number)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if !clientExists {
|
|
|
|
|
log.Printf("Got new client: %v\n", cli.tid.String())
|
|
|
|
|
|
|
|
|
|
err := cli.prepareFromRequest(req)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
@ -121,9 +138,13 @@ func (tftp *tftpServer) handleRequest(cli *client, req *request) error {
|
|
|
|
|
tftp.connections[cli.tid.String()] = cli
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: last data packet, close the client!
|
|
|
|
|
if req.opcode == opDATA {
|
|
|
|
|
_, err := io.Copy(cli.file, bytes.NewReader(req.body))
|
|
|
|
|
if err != nil {
|
|
|
|
|
if errors.Is(err, syscall.ENOSPC) {
|
|
|
|
|
err = newTFTPError(ecDSK)
|
|
|
|
|
}
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -138,6 +159,7 @@ func (tftp *tftpServer) handleResponse(cli *client, resp *response) error {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if cli.bytesLeft <= 0 {
|
|
|
|
|
log.Printf("Client '%v' has received a file.\n", cli.tid.String())
|
|
|
|
|
cli.file.Close()
|
|
|
|
|
delete(tftp.connections, cli.tid.String())
|
|
|
|
|
}
|
|
|
|
@ -148,6 +170,11 @@ func (tftp *tftpServer) handleResponse(cli *client, resp *response) error {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (tftp *tftpServer) sendError(cli *client, err *tftpError) (int, error) {
|
|
|
|
|
log.Println(err)
|
|
|
|
|
return tftp.sendResponse(cli, &response{opERROR, uint16(err.code), toCString(err.message.Error())})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (tftp *tftpServer) sendResponse(cli *client, resp *response) (int, error) {
|
|
|
|
|
header := []byte{0x0, byte(resp.opcode), 0x0, 0x0}
|
|
|
|
|
binary.BigEndian.PutUint16(header[2:], resp.number)
|
|
|
|
@ -171,12 +198,24 @@ func (cli *client) prepareFromRequest(req *request) error {
|
|
|
|
|
var err error
|
|
|
|
|
var f *os.File
|
|
|
|
|
|
|
|
|
|
// TODO: clean path to filename
|
|
|
|
|
if req.opcode == opRRQ {
|
|
|
|
|
f, err = os.Open(req.filename)
|
|
|
|
|
} else {
|
|
|
|
|
if _, err := os.Stat(req.filename); !errors.Is(err, fs.ErrNotExist) {
|
|
|
|
|
return newTFTPError(ecFEX)
|
|
|
|
|
}
|
|
|
|
|
f, err = os.Create(req.filename)
|
|
|
|
|
}
|
|
|
|
|
if err != nil {
|
|
|
|
|
switch {
|
|
|
|
|
case errors.Is(err, fs.ErrNotExist):
|
|
|
|
|
err = newTFTPError(ecFNF)
|
|
|
|
|
case errors.Is(err, fs.ErrPermission):
|
|
|
|
|
err = newTFTPError(ecACV)
|
|
|
|
|
case errors.Is(err, syscall.ENOSPC):
|
|
|
|
|
err = newTFTPError(ecDSK)
|
|
|
|
|
}
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -226,7 +265,7 @@ func newRequest(numRead int, body []byte) (*request, error) {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
if req.mode != "octet" {
|
|
|
|
|
return nil, fmt.Errorf("Incorrect mode '%v'. This server only supports 'octet' mode.\n", req.mode)
|
|
|
|
|
return nil, newTFTPError(ecNDEF, "Incorrect mode '%v'. This server supports only 'octet' mode.", req.mode)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
req.body = req.body[n:]
|
|
|
|
@ -270,16 +309,30 @@ func newResponse(cli *client, req *request) *response {
|
|
|
|
|
return resp
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type operation byte
|
|
|
|
|
type tftpError struct {
|
|
|
|
|
code errorCode
|
|
|
|
|
message error
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const (
|
|
|
|
|
opUNK operation = iota
|
|
|
|
|
opRRQ
|
|
|
|
|
opWRQ
|
|
|
|
|
opDATA
|
|
|
|
|
opACK
|
|
|
|
|
opERROR
|
|
|
|
|
)
|
|
|
|
|
func newTFTPError(code errorCode, clientMessage ...string) *tftpError {
|
|
|
|
|
if code > ecNOUS {
|
|
|
|
|
code = ecNDEF
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
message := tftpErrors[code]
|
|
|
|
|
if code == ecNDEF {
|
|
|
|
|
message = errors.New(strings.Join(clientMessage, " "))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return &tftpError{
|
|
|
|
|
code: code,
|
|
|
|
|
message: message,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (err *tftpError) Error() string {
|
|
|
|
|
return fmt.Sprintf("TFTP Error (%v): %v", err.code, err.message)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type errorCode uint16
|
|
|
|
|
|
|
|
|
@ -294,26 +347,27 @@ const (
|
|
|
|
|
ecNOUS
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
var errorMessages = map[errorCode]string{
|
|
|
|
|
ecNDEF: "",
|
|
|
|
|
ecFNF: "File not found.",
|
|
|
|
|
ecACV: "Access violation.",
|
|
|
|
|
ecDSK: "Disk full or allocation exceeded.",
|
|
|
|
|
ecILL: "Illegal TFTP operation.",
|
|
|
|
|
ecUTID: "Unknown transfer ID.",
|
|
|
|
|
ecFEX: "File already exists.",
|
|
|
|
|
ecNOUS: "No such user.",
|
|
|
|
|
var tftpErrors = [...]error{
|
|
|
|
|
ecNDEF: errors.New(""),
|
|
|
|
|
ecFNF: errors.New("File not found."),
|
|
|
|
|
ecACV: errors.New("Access violation."),
|
|
|
|
|
ecDSK: errors.New("Disk full or allocation exceeded."),
|
|
|
|
|
ecILL: errors.New("Illegal TFTP operation."),
|
|
|
|
|
ecUTID: errors.New("Unknown transfer ID."),
|
|
|
|
|
ecFEX: errors.New("File already exists."),
|
|
|
|
|
ecNOUS: errors.New("No such user."),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
func sendError(addr net.Addr, err errorCode, optMessage ...string) (int, error) {
|
|
|
|
|
message := errorMessages[err]
|
|
|
|
|
if err == ecNDEF {
|
|
|
|
|
message = strings.Join(optMessage, "\n")
|
|
|
|
|
}
|
|
|
|
|
return sendResponse(addr, &response{opERROR, uint16(err), toCString(message)})
|
|
|
|
|
}
|
|
|
|
|
*/
|
|
|
|
|
type operation byte
|
|
|
|
|
|
|
|
|
|
const (
|
|
|
|
|
opUNK operation = iota
|
|
|
|
|
opRRQ
|
|
|
|
|
opWRQ
|
|
|
|
|
opDATA
|
|
|
|
|
opACK
|
|
|
|
|
opERROR
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func toCString(src string) []byte {
|
|
|
|
|
return append([]byte(src), 0x0)
|
|
|
|
|