package websocket

import (
	"encoding/binary"
	"fmt"
	"io"
)

// Read reads a single WebSocket frame from r and returns its payload,
// unmasked if the frame carried a masking key.
//
// Only unfragmented text frames (FIN set, opcode 0x1) are accepted. Any other
// data or control frame yields an "unsupported WS frame" error, except a close
// frame (opcode 0x8): for that, Read returns io.EOF without consuming the rest
// of the frame. All three payload-length encodings of RFC 6455 §5.2 (7-bit,
// 16-bit and 64-bit) are supported. RSV bits are not checked and the payload
// is not validated as UTF-8.
//
// Read returns io.EOF only for a close frame or when r is exhausted before the
// first header byte. If r ends partway through a frame, Read returns
// io.ErrUnexpectedEOF so a truncated frame is not mistaken for a clean close.
func Read(r io.Reader) ([]byte, error) {
	// readFull reads exactly len(buf) bytes belonging to the current frame.
	// Hitting end of input here means the frame was cut short, so a bare
	// io.EOF is promoted to io.ErrUnexpectedEOF.
	readFull := func(buf []byte) error {
		_, err := io.ReadFull(r, buf)
		if err == io.EOF {
			return io.ErrUnexpectedEOF
		}
		return err
	}

	var header [2]byte
	// No bytes consumed yet: io.EOF here is a genuine end of stream.
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return nil, err
	}

	fin := header[0]&0x80 != 0
	opcode := header[0] & 0x0F
	if opcode == 0x8 { // close frame
		return nil, io.EOF
	}
	if !fin || opcode != 0x1 {
		return nil, fmt.Errorf("unsupported WS frame")
	}

	masked := header[1]&0x80 != 0

	// Decode the payload length: 0-125 inline, 126 => next 2 bytes,
	// 127 => next 8 bytes (both big-endian).
	var length uint64
	switch n := header[1] & 0x7F; n {
	case 126:
		var ext [2]byte
		if err := readFull(ext[:]); err != nil {
			return nil, err
		}
		length = uint64(binary.BigEndian.Uint16(ext[:]))
	case 127:
		var ext [8]byte
		if err := readFull(ext[:]); err != nil {
			return nil, err
		}
		length = binary.BigEndian.Uint64(ext[:])
		// RFC 6455 §5.2: the most significant bit of a 64-bit length MUST be 0.
		if length>>63 != 0 {
			return nil, fmt.Errorf("invalid WS payload length")
		}
	default:
		length = uint64(n)
	}

	var maskKey [4]byte
	if masked {
		if err := readFull(maskKey[:]); err != nil {
			return nil, err
		}
	}

	// Pre-allocate only modest payloads. For larger ones, grow the buffer as
	// bytes actually arrive so a forged header cannot force a huge allocation.
	const maxPrealloc = 1 << 16
	var payload []byte
	if length <= maxPrealloc {
		payload = make([]byte, length)
		if err := readFull(payload); err != nil {
			return nil, err
		}
	} else {
		var err error
		payload, err = io.ReadAll(io.LimitReader(r, int64(length)))
		if err != nil {
			return nil, err
		}
		if uint64(len(payload)) != length {
			return nil, io.ErrUnexpectedEOF
		}
	}

	if masked {
		for i := range payload {
			payload[i] ^= maskKey[i%4]
		}
	}

	return payload, nil
}
