diff options
Diffstat (limited to 'vendor/github.com/google/gopacket/tcpassembly/tcpreader/reader_test.go')
-rw-r--r-- | vendor/github.com/google/gopacket/tcpassembly/tcpreader/reader_test.go | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/vendor/github.com/google/gopacket/tcpassembly/tcpreader/reader_test.go b/vendor/github.com/google/gopacket/tcpassembly/tcpreader/reader_test.go new file mode 100644 index 0000000..7da9fd9 --- /dev/null +++ b/vendor/github.com/google/gopacket/tcpassembly/tcpreader/reader_test.go @@ -0,0 +1,129 @@ +// Copyright 2012 Google, Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-style license +// that can be found in the LICENSE file in the root of the source +// tree. + +package tcpreader + +import ( + "bytes" + "fmt" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/tcpassembly" + "io" + "net" + "testing" +) + +var netFlow gopacket.Flow + +func init() { + netFlow, _ = gopacket.FlowFromEndpoints( + layers.NewIPEndpoint(net.IP{1, 2, 3, 4}), + layers.NewIPEndpoint(net.IP{5, 6, 7, 8})) +} + +type readReturn struct { + data []byte + err error +} +type readSequence struct { + in []layers.TCP + want []readReturn +} +type testReaderFactory struct { + lossErrors bool + readSize int + ReaderStream + output chan []byte +} + +func (t *testReaderFactory) New(a, b gopacket.Flow) tcpassembly.Stream { + return &t.ReaderStream +} + +func testReadSequence(t *testing.T, lossErrors bool, readSize int, seq readSequence) { + f := &testReaderFactory{ReaderStream: NewReaderStream()} + f.ReaderStream.LossErrors = lossErrors + p := tcpassembly.NewStreamPool(f) + a := tcpassembly.NewAssembler(p) + buf := make([]byte, readSize) + go func() { + for i, test := range seq.in { + fmt.Println("Assembling", i) + a.Assemble(netFlow, &test) + fmt.Println("Assembly done") + } + }() + for i, test := range seq.want { + fmt.Println("Waiting for read", i) + n, err := f.Read(buf[:]) + fmt.Println("Got read") + if n != len(test.data) { + t.Errorf("test %d want %d bytes, got %d bytes", i, len(test.data), n) + } else if err != test.err { + t.Errorf("test %d want err %v, got err %v", i, test.err, err) + } else if !bytes.Equal(buf[:n], test.data) { + t.Errorf("test %d\nwant: %v\n got: %v\n", i, test.data, buf[:n]) + } + } + fmt.Println("All done reads") +} + +func TestRead(t *testing.T) { + testReadSequence(t, false, 10, readSequence{ + in: []layers.TCP{ + { + SYN: true, + SrcPort: 1, + DstPort: 2, + Seq: 1000, + BaseLayer: layers.BaseLayer{Payload: []byte{1, 2, 3}}, + }, + { + FIN: true, + SrcPort: 1, + DstPort: 2, + Seq: 1004, + }, + }, + want: []readReturn{ + {data: []byte{1, 2, 3}}, + {err: io.EOF}, + }, + }) +} + +func TestReadSmallChunks(t *testing.T) { + testReadSequence(t, false, 2, readSequence{ + in: []layers.TCP{ + { + SYN: true, + SrcPort: 1, + DstPort: 2, + Seq: 1000, + BaseLayer: layers.BaseLayer{Payload: []byte{1, 2, 3}}, + }, + { + FIN: true, + SrcPort: 1, + DstPort: 2, + Seq: 1004, + }, + }, + want: []readReturn{ + {data: []byte{1, 2}}, + {data: []byte{3}}, + {err: io.EOF}, + }, + }) +} + +func ExampleDiscardBytesToEOF() { + b := bytes.NewBuffer([]byte{1, 2, 3, 4, 5}) + fmt.Println(DiscardBytesToEOF(b)) + // Output: + // 5 +} |