blob: 4cba1ed2b1620a9c041d88af2949138a59da8803 [file] [log] [blame]
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"os"
"runtime"
"sync"
"testing"
"time"
)
// Fixture shared by the sendfile tests: a text file together with the
// length and digest the receiving side is expected to observe.
const (
	// newton is the relative path of the fixture file.
	newton = "../testdata/Isaac.Newton-Opticks.txt"
	// newtonLen is the expected size of the fixture, in bytes.
	newtonLen = 567_198
	// newtonSHA256 is the expected hex-encoded SHA-256 digest of the fixture.
	newtonSHA256 = "d4a9ac22462b35e7821a4f2706c211093da678620a8f9997989ee7cf8d507bbd"
)
// TestSendfile serves the whole newton fixture over a loopback TCP
// connection with io.Copy (expected to take the sendfile path where it is
// available). It checks that the server sent, and the client received,
// exactly newtonLen bytes, and that their SHA-256 digest equals newtonSHA256.
func TestSendfile(t *testing.T) {
	ln := newLocalListener(t, "tcp")
	defer ln.Close()

	// serverErrs carries at most one server-side error. The server always
	// closes it when finished, so the range loop at the end terminates.
	serverErrs := make(chan error, 1)

	// sendFixture copies the entire fixture to conn and verifies the count.
	sendFixture := func(conn Conn) error {
		src, err := os.Open(newton)
		if err != nil {
			return err
		}
		defer src.Close()

		// io.Copy should delegate to sendfile when available.
		sent, err := io.Copy(conn, src)
		switch {
		case err != nil:
			return err
		case sent != newtonLen:
			return fmt.Errorf("sent %d bytes; expected %d", sent, newtonLen)
		}
		return nil
	}

	go func(l Listener) {
		conn, err := l.Accept()
		if err != nil {
			serverErrs <- err
			close(serverErrs)
			return
		}
		go func() {
			defer close(serverErrs)
			defer conn.Close()
			if err := sendFixture(conn); err != nil {
				serverErrs <- err
			}
		}()
	}(ln)

	// Read everything the server sends, hashing it as it arrives.
	client, err := Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer client.Close()

	digest := sha256.New()
	received, err := io.Copy(digest, client)
	if err != nil {
		t.Error(err)
	}
	if received != newtonLen {
		t.Errorf("received %d bytes; expected %d", received, newtonLen)
	}
	if got := hex.EncodeToString(digest.Sum(nil)); got != newtonSHA256 {
		t.Error("retrieved data hash did not match")
	}
	for err := range serverErrs {
		t.Error(err)
	}
}
// TestSendfileParts checks that consecutive partial io.CopyN calls on the
// same file (each of which should use sendfile where available) continue
// from where the previous one stopped. The server sends three 3-byte
// chunks, and the client expects to receive exactly "Produced ".
func TestSendfileParts(t *testing.T) {
	ln := newLocalListener(t, "tcp")
	defer ln.Close()
	// errc carries at most one server-side error and is always closed by
	// the server goroutines, which lets the final range loop terminate.
	errc := make(chan error, 1)
	go func(ln Listener) {
		// Wait for a connection.
		conn, err := ln.Accept()
		if err != nil {
			errc <- err
			close(errc)
			return
		}
		go func() {
			defer close(errc)
			defer conn.Close()
			f, err := os.Open(newton)
			if err != nil {
				errc <- err
				return
			}
			defer f.Close()
			for i := 0; i < 3; i++ {
				// Return file data using io.CopyN, which should use
				// sendFile if available.
				_, err = io.CopyN(conn, f, 3)
				if err != nil {
					errc <- err
					return
				}
			}
		}()
	}(ln)
	c, err := Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	buf := new(bytes.Buffer)
	// Report a read failure directly instead of letting it surface only as
	// a content mismatch below.
	if _, err := buf.ReadFrom(c); err != nil {
		t.Error(err)
	}
	if want, have := "Produced ", buf.String(); have != want {
		t.Errorf("unexpected server reply %q, want %q", have, want)
	}
	for err := range errc {
		t.Error(err)
	}
}
// TestSendfileSeeked seeks the newton fixture to offset seekTo and then
// sends sendSize bytes with io.CopyN (which should use sendfile where
// available). The client checks only that exactly sendSize bytes arrive;
// the content itself is not compared.
func TestSendfileSeeked(t *testing.T) {
	ln := newLocalListener(t, "tcp")
	defer ln.Close()
	const seekTo = 65 << 10
	const sendSize = 10 << 10
	// errc carries at most one server-side error and is always closed by
	// the server goroutines, which lets the final range loop terminate.
	errc := make(chan error, 1)
	go func(ln Listener) {
		// Wait for a connection.
		conn, err := ln.Accept()
		if err != nil {
			errc <- err
			close(errc)
			return
		}
		go func() {
			defer close(errc)
			defer conn.Close()
			f, err := os.Open(newton)
			if err != nil {
				errc <- err
				return
			}
			defer f.Close()
			if _, err := f.Seek(seekTo, io.SeekStart); err != nil {
				errc <- err
				return
			}
			_, err = io.CopyN(conn, f, sendSize)
			if err != nil {
				errc <- err
				return
			}
		}()
	}(ln)
	c, err := Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	buf := new(bytes.Buffer)
	// Report a read failure directly instead of letting it surface only as
	// a length mismatch below.
	if _, err := buf.ReadFrom(c); err != nil {
		t.Error(err)
	}
	if buf.Len() != sendSize {
		t.Errorf("Got %d bytes; want %d", buf.Len(), sendSize)
	}
	for err := range errc {
		t.Error(err)
	}
}
// TestSendfilePipe tests that sendfile doesn't put a pipe into blocking mode.
//
// One byte is first copied from the read end of an os.Pipe to a TCP
// connection with io.CopyN, which calls into sendfile. Afterwards a very short
// read deadline is set on the pipe. A Read is then expected to fail with a
// timeout instead of blocking until the next write arrives.
func TestSendfilePipe(t *testing.T) {
	switch runtime.GOOS {
	case "plan9", "windows", "js", "wasip1":
		// These systems don't support deadlines on pipes.
		t.Skipf("skipping on %s", runtime.GOOS)
	}
	t.Parallel()
	ln := newLocalListener(t, "tcp")
	defer ln.Close()
	r, w, err := os.Pipe()
	if err != nil {
		t.Fatal(err)
	}
	defer w.Close()
	defer r.Close()
	// copied is closed by the server goroutine once the first byte has been
	// copied out of the pipe.
	copied := make(chan bool)
	// wg tracks every helper goroutine. It is waited on at the end of the test,
	// before the deferred closes run.
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		// Accept a connection and copy 1 byte from the read end of
		// the pipe to the connection. This will call into sendfile.
		//
		// NOTE(review): the error returns below skip close(copied), so a
		// failed Accept or CopyN leaves the main goroutine blocked on
		// <-copied indefinitely.
		defer wg.Done()
		conn, err := ln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		_, err = io.CopyN(conn, r, 1)
		if err != nil {
			t.Error(err)
			return
		}
		// Signal the main goroutine that we've copied the byte.
		close(copied)
	}()
	wg.Add(1)
	go func() {
		// Write 1 byte to the write end of the pipe.
		defer wg.Done()
		_, err := w.Write([]byte{'a'})
		if err != nil {
			t.Error(err)
		}
	}()
	wg.Add(1)
	go func() {
		// Connect to the server started two goroutines up and
		// discard any data that it writes.
		defer wg.Done()
		conn, err := Dial("tcp", ln.Addr().String())
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		io.Copy(io.Discard, conn)
	}()
	// Wait for the byte to be copied, meaning that sendfile has
	// been called on the pipe.
	<-copied
	// Set a very short deadline on the read end of the pipe.
	if err := r.SetDeadline(time.Now().Add(time.Microsecond)); err != nil {
		t.Fatal(err)
	}
	wg.Add(1)
	go func() {
		// Wait for much longer than the deadline and write a byte
		// to the pipe. The write's result is ignored: its only purpose
		// is to give a blocking-mode Read something to return.
		defer wg.Done()
		time.Sleep(50 * time.Millisecond)
		w.Write([]byte{'b'})
	}()
	// If this read does not time out, the pipe was incorrectly
	// put into blocking mode.
	_, err = r.Read(make([]byte, 1))
	if err == nil {
		t.Error("Read did not time out")
	} else if !os.IsTimeout(err) {
		t.Errorf("got error %v, expected a time out", err)
	}
	wg.Wait()
}
// TestSendfileOnWriteTimeoutExceeded is a regression test for issue 43822.
// When a connection's write deadline has already passed, io.Copy from a file
// into it must fail with os.ErrDeadlineExceeded without sending any data. The
// peer must then see a clean EOF.
func TestSendfileOnWriteTimeoutExceeded(t *testing.T) {
	ln := newLocalListener(t, "tcp")
	defer ln.Close()

	// serve accepts one connection and tries to send the fixture after the
	// connection's write deadline has expired. It reports nil only when the
	// copy fails with os.ErrDeadlineExceeded.
	serve := func() error {
		conn, err := ln.Accept()
		if err != nil {
			return err
		}
		defer conn.Close()

		// A deadline one hour in the past guarantees that every write
		// times out.
		if err := conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)); err != nil {
			return err
		}

		f, err := os.Open(newton)
		if err != nil {
			return err
		}
		defer f.Close()

		_, err = io.Copy(conn, f)
		switch {
		case errors.Is(err, os.ErrDeadlineExceeded):
			return nil
		case err == nil:
			return fmt.Errorf("expected ErrDeadlineExceeded, but got nil")
		default:
			return err
		}
	}

	// errc receives exactly one result from serve and is then closed.
	errc := make(chan error, 1)
	go func() {
		errc <- serve()
		close(errc)
	}()

	client, err := Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer client.Close()

	// The server never manages to write, so the client should read a clean EOF.
	got, err := io.Copy(io.Discard, client)
	if err != nil {
		t.Fatalf("expected nil error, but got %v", err)
	}
	if got != 0 {
		t.Fatalf("expected receive zero, but got %d byte(s)", got)
	}
	if err := <-errc; err != nil {
		t.Fatal(err)
	}
}
// BenchmarkSendfileZeroBytes benchmarks io.Copy from a file that is still
// being written into a TCP connection. A writer goroutine appends b.N bytes
// to a temp file one byte at a time, syncing every 1000 bytes. A server
// goroutine repeatedly calls io.Copy from that file into an accepted
// connection until ctx is cancelled, so many copies find nothing new to send.
// The client reads exactly b.N bytes.
func BenchmarkSendfileZeroBytes(b *testing.B) {
	var (
		wg          sync.WaitGroup
		ctx, cancel = context.WithCancel(context.Background())
	)
	// Deferred calls run LIFO. The later-deferred closes (conn, tempFile, ln)
	// run first, then cancel, then wg.Wait. Deferring cancel ensures the
	// server loop is told to stop even when a b.Fatalf below skips the
	// explicit cancel at the end. Without it, the server could keep issuing
	// empty copies forever and wg.Wait would never return.
	defer wg.Wait()
	defer cancel()
	ln := newLocalListener(b, "tcp")
	defer ln.Close()
	tempFile, err := os.CreateTemp(b.TempDir(), "test.txt")
	if err != nil {
		b.Fatalf("failed to create temp file: %v", err)
	}
	defer tempFile.Close()
	fileName := tempFile.Name()
	dataSize := b.N
	wg.Add(1)
	go func(f *os.File) {
		defer wg.Done()
		for i := 0; i < dataSize; i++ {
			if _, err := f.Write([]byte{1}); err != nil {
				b.Errorf("failed to write: %v", err)
				return
			}
			if i%1000 == 0 {
				f.Sync()
			}
		}
	}(tempFile)
	b.ResetTimer()
	b.ReportAllocs()
	wg.Add(1)
	go func(ln Listener, fileName string) {
		defer wg.Done()
		conn, err := ln.Accept()
		if err != nil {
			b.Errorf("failed to accept: %v", err)
			return
		}
		defer conn.Close()
		f, err := os.OpenFile(fileName, os.O_RDONLY, 0660)
		if err != nil {
			b.Errorf("failed to open file: %v", err)
			return
		}
		defer f.Close()
		for {
			if ctx.Err() != nil {
				return
			}
			if _, err := io.Copy(conn, f); err != nil {
				b.Errorf("failed to copy: %v", err)
				return
			}
		}
	}(ln, fileName)
	conn, err := Dial("tcp", ln.Addr().String())
	if err != nil {
		b.Fatalf("failed to dial: %v", err)
	}
	defer conn.Close()
	n, err := io.CopyN(io.Discard, conn, int64(dataSize))
	if err != nil {
		b.Fatalf("failed to copy: %v", err)
	}
	if n != int64(dataSize) {
		b.Fatalf("expected %d copied bytes, but got %d", dataSize, n)
	}
	// Stop the server loop before the client connection is closed, as on
	// the original success path. The deferred cancel is then a no-op.
	cancel()
}