diff --git a/countrw/cr.go b/countrw/cr.go deleted file mode 100644 index 817dd25..0000000 --- a/countrw/cr.go +++ /dev/null @@ -1,26 +0,0 @@ -package countrw - -import ( - "io" - "sync/atomic" -) - -type CountReader struct { - r io.Reader - c *uint64 -} - -func NewCountReader(r io.Reader) *CountReader { - return &CountReader{r: r, c: new(uint64)} -} - -func (c *CountReader) Reset() { atomic.StoreUint64(c.c, 0) } -func (c *CountReader) Count() uint64 { return atomic.LoadUint64(c.c) } - -func (c *CountReader) Read(p []byte) (int, error) { - n, err := c.r.Read(p) - atomic.AddUint64(c.c, uint64(n)) - return n, err -} - -var _ io.Reader = (*CountReader)(nil) diff --git a/iout/copy_range.go b/iout/copy_range.go new file mode 100644 index 0000000..f4f1f04 --- /dev/null +++ b/iout/copy_range.go @@ -0,0 +1,23 @@ +package iout + +import ( + "errors" + "fmt" + "io" +) + +func CopyRange(w io.Writer, r io.Reader, start, length int64) error { + if _, err := io.CopyN(io.Discard, r, start); err != nil && !errors.Is(err, io.EOF) { + return fmt.Errorf("discard %d: %w", start, err) + } + if length == 0 { + if _, err := io.Copy(w, r); err != nil && !errors.Is(err, io.EOF) { + return fmt.Errorf("direct copy: %w", err) + } + return nil + } + if _, err := io.CopyN(w, io.MultiReader(r, NewNullReader()), length); err != nil && !errors.Is(err, io.EOF) { + return fmt.Errorf("copy %d: %w", length, err) + } + return nil +} diff --git a/iout/copy_range_test.go b/iout/copy_range_test.go new file mode 100644 index 0000000..b6a1867 --- /dev/null +++ b/iout/copy_range_test.go @@ -0,0 +1,42 @@ +package iout_test + +import ( + "bytes" + "testing" + + "github.com/matryer/is" + "go.senan.xyz/gonic/iout" +) + +func TestCopyRange(t *testing.T) { + t.Parallel() + is := is.New(t) + + realLength := 50 + cr := func(start, length int64) []byte { + is.Helper() + var data []byte + for i := 0; i < realLength; i++ { + data = append(data, byte(i%10)) + } + var buff bytes.Buffer + is.NoErr(iout.CopyRange(&buff, bytes.NewReader(data), start, length)) + return buff.Bytes() + } + + // range + is.Equal(len(cr(0, 50)), 50) + is.Equal(len(cr(10, 10)), 10) + is.Equal(cr(10, 10)[0], byte(0)) + is.Equal(cr(10, 10)[5], byte(5)) + is.Equal(cr(25, 35)[0], byte(5)) + is.Equal(cr(25, 35)[5], byte(0)) + + // 0 padding + is.Equal(len(cr(0, 5000)), 5000) + is.Equal(cr(0, 5000)[50:], make([]byte, 5000-50)) + + // no bound + is.Equal(len(cr(0, 0)), 50) + is.Equal(len(cr(50, 0)), 0) +} diff --git a/countrw/cw.go b/iout/countrw.go similarity index 51% rename from countrw/cw.go rename to iout/countrw.go index 5d27d9c..35a9364 100644 --- a/countrw/cw.go +++ b/iout/countrw.go @@ -1,10 +1,30 @@ -package countrw +package iout import ( "io" "sync/atomic" ) +type CountReader struct { + r io.Reader + c *uint64 +} + +func NewCountReader(r io.Reader) *CountReader { + return &CountReader{r: r, c: new(uint64)} +} + +func (c *CountReader) Reset() { atomic.StoreUint64(c.c, 0) } +func (c *CountReader) Count() uint64 { return atomic.LoadUint64(c.c) } + +func (c *CountReader) Read(p []byte) (int, error) { + n, err := c.r.Read(p) + atomic.AddUint64(c.c, uint64(n)) + return n, err +} + +var _ io.Reader = (*CountReader)(nil) + type CountWriter struct { r io.Writer c *uint64 diff --git a/iout/null_reader.go b/iout/null_reader.go new file mode 100644 index 0000000..d7477d6 --- /dev/null +++ b/iout/null_reader.go @@ -0,0 +1,18 @@ +package iout + +import "io" + +type nullReader struct{} + +func NewNullReader() io.Reader { + return &nullReader{} +} + +func (*nullReader) Read(p []byte) (n int, err error) { + for b := range p { + p[b] = 0 + } + return len(p), nil +} + +var _ io.Reader = (*nullReader)(nil) diff --git a/iout/tee_closer.go b/iout/tee_closer.go new file mode 100644 index 0000000..92cfea6 --- /dev/null +++ b/iout/tee_closer.go @@ -0,0 +1,28 @@ +package iout + +import "io" + +type teeCloser struct { + r io.ReadCloser + w io.WriteCloser +} + +func NewTeeCloser(r io.ReadCloser, w io.WriteCloser) io.ReadCloser { + return &teeCloser{r, w} +} + +func (t *teeCloser) Read(p []byte) (int, error) { + n, err := t.r.Read(p) + if n > 0 { + if n, err := t.w.Write(p[:n]); err != nil { + return n, err + } + } + return n, err +} + +func (t *teeCloser) Close() error { + t.r.Close() + t.w.Close() + return nil +}