refactor(countrw): move to iout

This commit is contained in:
sentriz
2022-04-12 23:25:54 +01:00
parent d7655cb9d1
commit fd211d706a
6 changed files with 132 additions and 27 deletions

View File

@@ -1,26 +0,0 @@
package countrw
import (
"io"
"sync/atomic"
)
type CountReader struct {
r io.Reader
c *uint64
}
func NewCountReader(r io.Reader) *CountReader {
return &CountReader{r: r, c: new(uint64)}
}
func (c *CountReader) Reset() { atomic.StoreUint64(c.c, 0) }
func (c *CountReader) Count() uint64 { return atomic.LoadUint64(c.c) }
func (c *CountReader) Read(p []byte) (int, error) {
n, err := c.r.Read(p)
atomic.AddUint64(c.c, uint64(n))
return n, err
}
var _ io.Reader = (*CountReader)(nil)

23
iout/copy_range.go Normal file
View File

@@ -0,0 +1,23 @@
package iout
import (
"errors"
"fmt"
"io"
)
func CopyRange(w io.Writer, r io.Reader, start, length int64) error {
if _, err := io.CopyN(io.Discard, r, start); err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("discard %d: %w", start, err)
}
if length == 0 {
if _, err := io.Copy(w, r); err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("direct copy: %w", err)
}
return nil
}
if _, err := io.CopyN(w, io.MultiReader(r, NewNullReader()), length); err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("copy %d: %w", length, err)
}
return nil
}

42
iout/copy_range_test.go Normal file
View File

@@ -0,0 +1,42 @@
package iout_test
import (
"bytes"
"testing"
"github.com/matryer/is"
"go.senan.xyz/gonic/iout"
)
func TestCopyRange(t *testing.T) {
t.Parallel()
is := is.New(t)
realLength := 50
cr := func(start, length int64) []byte {
is.Helper()
var data []byte
for i := 0; i < realLength; i++ {
data = append(data, byte(i%10))
}
var buff bytes.Buffer
is.NoErr(iout.CopyRange(&buff, bytes.NewReader(data), start, length))
return buff.Bytes()
}
// range
is.Equal(len(cr(0, 50)), 50)
is.Equal(len(cr(10, 10)), 10)
is.Equal(cr(10, 10)[0], byte(0))
is.Equal(cr(10, 10)[5], byte(5))
is.Equal(cr(25, 35)[0], byte(5))
is.Equal(cr(25, 35)[5], byte(0))
// 0 padding
is.Equal(len(cr(0, 5000)), 5000)
is.Equal(cr(0, 5000)[50:], make([]byte, 5000-50))
// no bound
is.Equal(len(cr(0, 0)), 50)
is.Equal(len(cr(50, 0)), 0)
}

View File

@@ -1,10 +1,30 @@
package countrw package iout
import ( import (
"io" "io"
"sync/atomic" "sync/atomic"
) )
type CountReader struct {
r io.Reader
c *uint64
}
func NewCountReader(r io.Reader) *CountReader {
return &CountReader{r: r, c: new(uint64)}
}
func (c *CountReader) Reset() { atomic.StoreUint64(c.c, 0) }
func (c *CountReader) Count() uint64 { return atomic.LoadUint64(c.c) }
func (c *CountReader) Read(p []byte) (int, error) {
n, err := c.r.Read(p)
atomic.AddUint64(c.c, uint64(n))
return n, err
}
var _ io.Reader = (*CountReader)(nil)
type CountWriter struct { type CountWriter struct {
r io.Writer r io.Writer
c *uint64 c *uint64

18
iout/null_reader.go Normal file
View File

@@ -0,0 +1,18 @@
package iout
import "io"
type nullReader struct{}
func NewNullReader() io.Reader {
return &nullReader{}
}
func (*nullReader) Read(p []byte) (n int, err error) {
for b := range p {
p[b] = 0
}
return len(p), nil
}
var _ io.Reader = (*nullReader)(nil)

28
iout/tee_closer.go Normal file
View File

@@ -0,0 +1,28 @@
package iout
import "io"
type teeCloser struct {
r io.ReadCloser
w io.WriteCloser
}
func NewTeeCloser(r io.ReadCloser, w io.WriteCloser) io.ReadCloser {
return &teeCloser{r, w}
}
func (t *teeCloser) Read(p []byte) (int, error) {
n, err := t.r.Read(p)
if n > 0 {
if n, err := t.w.Write(p[:n]); err != nil {
return n, err
}
}
return n, err
}
func (t *teeCloser) Close() error {
t.r.Close()
t.w.Close()
return nil
}