diff --git a/zstd.go b/zstd.go index 8499bf1..b0c4faf 100644 --- a/zstd.go +++ b/zstd.go @@ -13,6 +13,8 @@ import ( "bytes" "errors" "io/ioutil" + "runtime" + "sync" "unsafe" ) @@ -57,6 +59,38 @@ func cCompressBound(srcSize int) int { return int(C.ZSTD_compressBound(C.size_t(srcSize))) } +// Keep pools of reusable contexts and use them in calls to +// ZSTD_compressCCtx/ZSTD_decompressDCtx. Those functions reset +// session state so pooling context objects is safe and significantly +// reduces allocation churn. +type cctxWrapper struct { + cctx *C.ZSTD_CCtx +} + +type dctxWrapper struct { + dctx *C.ZSTD_DCtx +} + +var cctxPool = sync.Pool{ + New: func() interface{} { + w := &cctxWrapper{cctx: C.ZSTD_createCCtx()} + runtime.SetFinalizer(w, func(w *cctxWrapper) { + C.ZSTD_freeCCtx(w.cctx) + }) + return w + }, +} + +var dctxPool = sync.Pool{ + New: func() interface{} { + w := &dctxWrapper{dctx: C.ZSTD_createDCtx()} + runtime.SetFinalizer(w, func(w *dctxWrapper) { + C.ZSTD_freeDCtx(w.dctx) + }) + return w + }, +} + // decompressSizeHint tries to give a hint on how much of the output buffer size we should have // based on zstd frame descriptors. To prevent DOS from maliciously-created payloads, limit the size func decompressSizeHint(src []byte) int { @@ -100,19 +134,23 @@ func CompressLevel(dst, src []byte, level int) ([]byte, error) { dst = make([]byte, bound) } + w := cctxPool.Get().(*cctxWrapper) + // We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics. // This means we need to special case empty input. See: // https://github.com/golang/go/issues/14210#issuecomment-346402945 var cWritten C.size_t if len(src) == 0 { - cWritten = C.ZSTD_compress( + cWritten = C.ZSTD_compressCCtx( + w.cctx, unsafe.Pointer(&dst[0]), C.size_t(len(dst)), unsafe.Pointer(nil), C.size_t(0), C.int(level)) } else { - cWritten = C.ZSTD_compress( + cWritten = C.ZSTD_compressCCtx( + w.cctx, unsafe.Pointer(&dst[0]), C.size_t(len(dst)), unsafe.Pointer(&src[0]), @@ -120,6 +158,8 @@ func CompressLevel(dst, src []byte, level int) ([]byte, error) { C.int(level)) } + cctxPool.Put(w) + written := int(cWritten) // Check if the return is an Error code if err := getError(written); err != nil { @@ -165,10 +205,13 @@ func Decompress(dst, src []byte) ([]byte, error) { // It returns the number of bytes copied and an error if any is encountered. If // dst is too small, DecompressInto errors. func DecompressInto(dst, src []byte) (int, error) { - written := int(C.ZSTD_decompress( + w := dctxPool.Get().(*dctxWrapper) + written := int(C.ZSTD_decompressDCtx( + w.dctx, unsafe.Pointer(&dst[0]), C.size_t(len(dst)), unsafe.Pointer(&src[0]), C.size_t(len(src)))) + dctxPool.Put(w) return written, getError(written) }