Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"bytes"
"errors"
"io/ioutil"
"runtime"
"sync"
"unsafe"
)

Expand Down Expand Up @@ -57,6 +59,38 @@ func cCompressBound(srcSize int) int {
return int(C.ZSTD_compressBound(C.size_t(srcSize)))
}

// Keep pools of reusable contexts and use them in calls to
// ZSTD_compressCCtx/ZSTD_decompressDCtx. Those functions reset
// session state so pooling context objects is safe and significantly
// reduces allocation churn.
type cctxWrapper struct {
cctx *C.ZSTD_CCtx
}

type dctxWrapper struct {
dctx *C.ZSTD_DCtx
}

var cctxPool = sync.Pool{
New: func() interface{} {
w := &cctxWrapper{cctx: C.ZSTD_createCCtx()}
runtime.SetFinalizer(w, func(w *cctxWrapper) {
C.ZSTD_freeCCtx(w.cctx)
})
return w
},
}

var dctxPool = sync.Pool{
New: func() interface{} {
w := &dctxWrapper{dctx: C.ZSTD_createDCtx()}
runtime.SetFinalizer(w, func(w *dctxWrapper) {
C.ZSTD_freeDCtx(w.dctx)
})
return w
},
}

// decompressSizeHint tries to give a hint on how much of the output buffer size we should have
// based on zstd frame descriptors. To prevent DOS from maliciously-created payloads, limit the size
func decompressSizeHint(src []byte) int {
Expand Down Expand Up @@ -100,26 +134,32 @@ func CompressLevel(dst, src []byte, level int) ([]byte, error) {
dst = make([]byte, bound)
}

w := cctxPool.Get().(*cctxWrapper)

// We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics.
// This means we need to special case empty input. See:
// https://github.com/golang/go/issues/14210#issuecomment-346402945
var cWritten C.size_t
if len(src) == 0 {
cWritten = C.ZSTD_compress(
cWritten = C.ZSTD_compressCCtx(
w.cctx,
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(nil),
C.size_t(0),
C.int(level))
} else {
cWritten = C.ZSTD_compress(
cWritten = C.ZSTD_compressCCtx(
w.cctx,
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(&src[0]),
C.size_t(len(src)),
C.int(level))
}

cctxPool.Put(w)

written := int(cWritten)
// Check if the return is an Error code
if err := getError(written); err != nil {
Expand Down Expand Up @@ -165,10 +205,13 @@ func Decompress(dst, src []byte) ([]byte, error) {
// It returns the number of bytes copied and an error if any is encountered. If
// dst is too small, DecompressInto errors.
func DecompressInto(dst, src []byte) (int, error) {
written := int(C.ZSTD_decompress(
w := dctxPool.Get().(*dctxWrapper)
written := int(C.ZSTD_decompressDCtx(
w.dctx,
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(&src[0]),
C.size_t(len(src))))
dctxPool.Put(w)
return written, getError(written)
}