This commit is contained in:
Emad Elsaid 2026-04-22 21:54:36 -03:00 committed by GitHub
commit 75f61bb6fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 266 additions and 409 deletions

View file

@ -75,6 +75,8 @@ type Causal struct {
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
maskBuf []float32
}
type cacheCell struct {
@ -365,7 +367,14 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
length := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, c.curBatchSize*length)
size := c.curBatchSize * length
if cap(c.maskBuf) < size {
c.maskBuf = make([]float32, size)
}
mask := c.maskBuf[:size]
for i := range mask {
mask[i] = 0
}
for i := range c.curBatchSize {
enabled := !slices.Contains(c.opts.Except, i)

View file

@ -465,16 +465,16 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context.Forward(out, mask).Compute(out, mask)
if !slices.Equal(out.Floats(), test.expected) {
t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
if !slices.Equal(out.Floats(nil), test.expected) {
t.Errorf("TestCache: have %v; want %v", out.Floats(nil), test.expected)
}
if !slices.Equal(out.Shape(), test.expectedShape) {
t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
}
if !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
if !slices.Equal(mask.Floats(nil), test.expectedMask) {
t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(nil), test.expectedMask)
}
})
}
@ -734,10 +734,13 @@ func (t *testTensor) DType() ml.DType {
return t.dtype
}
func (t *testTensor) Floats() []float32 {
out := make([]float32, len(t.data))
copy(out, t.data)
return out
func (t *testTensor) Floats(dst []float32) []float32 {
if cap(dst) < len(t.data) {
dst = make([]float32, len(t.data))
}
dst = dst[:len(t.data)]
copy(dst, t.data)
return dst
}
func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {

View file

@ -136,7 +136,7 @@ type Tensor interface {
Cast(ctx Context, dtype DType) Tensor
Bytes() []byte
Floats() []float32
Floats(dst []float32) []float32
BackendGet() []float32
FromBytes([]byte)

View file

@ -44,6 +44,17 @@ var (
backends map[C.ggml_backend_dev_t]C.ggml_backend_t
)
var tensorPool = sync.Pool{New: func() any { return new(Tensor) }}
func poolTensor(ctx *Context, b *Backend, t *C.struct_ggml_tensor) *Tensor {
r := tensorPool.Get().(*Tensor)
r.b, r.t, r.sync = b, t, nil
if ctx != nil && ctx.wrappers != nil {
*ctx.wrappers = append(*ctx.wrappers, r)
}
return r
}
var initDevices = sync.OnceFunc(func() {
ggml.OnceLoad()
@ -670,6 +681,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
}
var allocatedBuffers []C.ggml_backend_buffer_t
var wrappers []*Tensor
return &Context{
b: b,
@ -679,6 +691,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
no_alloc: true,
}),
allocatedBuffers: &allocatedBuffers,
wrappers: &wrappers,
layer: -1,
}
}
@ -757,6 +770,9 @@ type Context struct {
// maxGraphNodes is the maximum allowed number of graph nodes in this context
maxGraphNodes int
// wrappers tracks pooled Tensor wrappers so they can be returned on Close
wrappers *[]*Tensor
// layer is the graph layer that this context is allocating for - assumed to be cache
layer int
}
@ -769,6 +785,7 @@ func (c *Context) Input() ml.Context {
buft: c.b.input,
allocatedBuffers: c.allocatedBuffers,
maxGraphNodes: c.maxGraphNodes,
wrappers: c.wrappers,
layer: -1,
}
}
@ -784,6 +801,7 @@ func (c *Context) Layer(i int) ml.Context {
buft: layer.bt,
allocatedBuffers: c.allocatedBuffers,
maxGraphNodes: c.maxGraphNodes,
wrappers: c.wrappers,
layer: i,
}
}
@ -895,7 +913,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) *Tensor {
if len(shape) < 1 || shape[0] == 0 {
var shape C.int64_t = 0
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
return poolTensor(c, c.b, C.ggml_new_tensor(c.ctx, cdtype, 1, &shape))
} else if len(shape) > 4 {
panic("unsupported number of dimensions")
}
@ -920,7 +938,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) *Tensor {
*c.allocatedBuffers = append(*c.allocatedBuffers, b)
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
return &Tensor{b: c.b, t: t}
return poolTensor(c, c.b, t)
}
func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
@ -988,10 +1006,7 @@ func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
switch dtype {
case ml.DTypeF32:
// ggml_arange creates a float32 tensor
return &Tensor{
b: c.b,
t: C.ggml_arange(c.ctx, C.float(start), C.float(stop), C.float(step)),
}
return poolTensor(&c, c.b, C.ggml_arange(c.ctx, C.float(start), C.float(stop), C.float(step)))
case ml.DTypeI32:
// ggml_cast does not support float32 to int32 conversion
arange := make([]int32, 0, int((stop-start)/step))
@ -1013,6 +1028,14 @@ func (c *Context) Close() {
*c.allocatedBuffers = nil
C.ggml_free(c.ctx)
if c.wrappers != nil {
for _, w := range *c.wrappers {
w.b, w.t, w.sync = nil, nil, nil
tensorPool.Put(w)
}
*c.wrappers = (*c.wrappers)[:0]
}
}
}
@ -1058,15 +1081,18 @@ func (t *Tensor) Bytes() (data []byte) {
return
}
func (t *Tensor) Floats() (data []float32) {
if t.sync != nil {
data = make([]float32, C.ggml_nelements(t.t))
t.sync()
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
func (t *Tensor) Floats(dst []float32) []float32 {
if t.sync == nil {
return dst[:0]
}
return
n := int(C.ggml_nelements(t.t))
if cap(dst) < n {
dst = make([]float32, n)
}
dst = dst[:n]
t.sync()
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&dst[0]), 0, C.ggml_nbytes(t.t))
return dst
}
func (t *Tensor) BackendGet() []float32 {
@ -1145,24 +1171,15 @@ func ggmlDType(dtype ml.DType) uint32 {
}
func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cast(ctx.(*Context).ctx, t.t, ggmlDType(dtype)),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_cast(ctx.(*Context).ctx, t.t, ggmlDType(dtype)))
}
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t))
}
func (t *Tensor) Sub(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sub(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_sub(ctx.(*Context).ctx, t.t, t2.(*Tensor).t))
}
func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
@ -1180,10 +1197,7 @@ func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
}
tmpl := C.ggml_new_tensor(ctx.(*Context).ctx, t.t._type, C.int(len(shape)), unsafe.SliceData(shape))
return &Tensor{
b: t.b,
t: C.ggml_repeat(ctx.(*Context).ctx, t.t, tmpl),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_repeat(ctx.(*Context).ctx, t.t, tmpl))
}
func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
@ -1195,10 +1209,7 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
}
func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)))
}
func (t *Tensor) Contiguous(ctx ml.Context, shape ...int) ml.Tensor {
@ -1206,49 +1217,29 @@ func (t *Tensor) Contiguous(ctx ml.Context, shape ...int) ml.Tensor {
inferShape(t, shape)
}
c := ctx.(*Context)
switch len(shape) {
case 0:
return &Tensor{
b: t.b,
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
}
return poolTensor(c, t.b, C.ggml_cont(c.ctx, t.t))
case 1:
return &Tensor{
b: t.b,
t: C.ggml_cont_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
}
return poolTensor(c, t.b, C.ggml_cont_1d(c.ctx, t.t, C.int64_t(shape[0])))
case 2:
return &Tensor{
b: t.b,
t: C.ggml_cont_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
}
return poolTensor(c, t.b, C.ggml_cont_2d(c.ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])))
case 3:
return &Tensor{
b: t.b,
t: C.ggml_cont_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
}
return poolTensor(c, t.b, C.ggml_cont_3d(c.ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])))
case 4:
return &Tensor{
b: t.b,
t: C.ggml_cont_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
}
return poolTensor(c, t.b, C.ggml_cont_4d(c.ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])))
default:
panic("unsupported number of dimensions")
}
}
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t))
}
func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t))
}
// Mulmat performs matrix multiplication between two tensors.
@ -1257,41 +1248,25 @@ func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
//
// Note: this is similar to matmul(t2, t.tranpose(-1, -2)) in other libraries.
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t))
}
func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
return &Tensor{
b: t.b,
t: mul,
}
return poolTensor(ctx.(*Context), t.b, mul)
}
func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mul_mat_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_mul_mat_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t))
}
func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_add_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_add_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t))
}
func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)))
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
@ -1303,7 +1278,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
}
}
return &Tensor{b: t.b, t: tt}
return poolTensor(ctx.(*Context), t.b, tt)
}
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
@ -1312,7 +1287,7 @@ func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
}
return &Tensor{b: t.b, t: tt}
return poolTensor(ctx.(*Context), t.b, tt)
}
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
@ -1322,17 +1297,13 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
panic("cuda does not support 4d tensors")
}
return &Tensor{
b: t.b,
t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
return poolTensor(ctx.(*Context), t.b,
C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])))
}
func (t *Tensor) PadExt(ctx ml.Context, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_pad_ext(ctx.(*Context).ctx, t.t, C.int(lp0), C.int(rp0), C.int(lp1), C.int(rp1), C.int(lp2), C.int(rp2), C.int(lp3), C.int(rp3)),
}
return poolTensor(ctx.(*Context), t.b,
C.ggml_pad_ext(ctx.(*Context).ctx, t.t, C.int(lp0), C.int(rp0), C.int(lp1), C.int(rp1), C.int(lp2), C.int(rp2), C.int(lp3), C.int(rp3)))
}
// Permute permutes t according to order. Permute panics if the number of dimensions
@ -1347,46 +1318,32 @@ func (t *Tensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
order = append(order, i)
}
return &Tensor{
b: t.b,
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(order[0]), C.int(order[1]), C.int(order[2]), C.int(order[3])),
}
return poolTensor(ctx.(*Context), t.b,
C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(order[0]), C.int(order[1]), C.int(order[2]), C.int(order[3])))
}
func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t))
}
func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_set_rows(ctx.(*Context).ctx, t.t, src.(*Tensor).t, idxs.(*Tensor).t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_set_rows(ctx.(*Context).ctx, t.t, src.(*Tensor).t, idxs.(*Tensor).t))
}
func (t *Tensor) SetInplace(ctx ml.Context, src ml.Tensor, nb1, nb2, nb3, offset int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_set_inplace(
ctx.(*Context).ctx,
t.t,
src.(*Tensor).t,
C.size_t(nb1),
C.size_t(nb2),
C.size_t(nb3),
C.size_t(offset),
),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_set_inplace(
ctx.(*Context).ctx,
t.t,
src.(*Tensor).t,
C.size_t(nb1),
C.size_t(nb2),
C.size_t(nb3),
C.size_t(offset),
))
}
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t))
}
// inferShape updates shape in place to automatically set a single -1 dimesion
@ -1430,119 +1387,73 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
inferShape(t, shape)
}
rc := ctx.(*Context)
switch len(shape) {
case 1:
return &Tensor{
b: t.b,
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
}
return poolTensor(rc, t.b, C.ggml_reshape_1d(rc.ctx, t.t, C.int64_t(shape[0])))
case 2:
return &Tensor{
b: t.b,
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
}
return poolTensor(rc, t.b, C.ggml_reshape_2d(rc.ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])))
case 3:
return &Tensor{
b: t.b,
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
}
return poolTensor(rc, t.b, C.ggml_reshape_3d(rc.ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])))
case 4:
return &Tensor{
b: t.b,
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
}
return poolTensor(rc, t.b, C.ggml_reshape_4d(rc.ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])))
default:
panic("unsupported number of dimensions")
}
}
func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)))
}
func (t *Tensor) SumRows(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sum_rows(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_sum_rows(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_soft_max(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) Sin(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sin(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_sin(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) Cos(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cos(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_cos(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sigmoid_inplace(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_sigmoid_inplace(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) SigmoidOut(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sigmoid(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_sigmoid(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
vc := ctx.(*Context)
switch len(shape) {
case 1:
return &Tensor{
b: t.b,
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
}
return poolTensor(vc, t.b, C.ggml_view_1d(vc.ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)))
case 3:
return &Tensor{
b: t.b,
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]),
C.size_t(shape[1]),
C.size_t(offset)),
}
return poolTensor(vc, t.b, C.ggml_view_2d(vc.ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]),
C.size_t(shape[1]),
C.size_t(offset)))
case 5:
return &Tensor{
b: t.b,
t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
C.size_t(shape[1]), C.size_t(shape[3]),
C.size_t(offset)),
}
return poolTensor(vc, t.b, C.ggml_view_3d(vc.ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
C.size_t(shape[1]), C.size_t(shape[3]),
C.size_t(offset)))
case 7:
return &Tensor{
b: t.b,
t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
C.size_t(offset)),
}
return poolTensor(vc, t.b, C.ggml_view_4d(vc.ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
C.size_t(offset)))
default:
panic("unsupported number of dimensions")
}
@ -1602,34 +1513,23 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
cmp.Or(C.float(opts.YaRN.BetaSlow), 1),
)
}
return &Tensor{b: t.b, t: tt}
return poolTensor(ctx.(*Context), t.b, tt)
}
func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
}
return poolTensor(ctx.(*Context), t.b,
C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32))
}
func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
b: t.b,
t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
}
}
return &Tensor{
b: t.b,
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
return poolTensor(ctx.(*Context), t.b, C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t))
}
return poolTensor(ctx.(*Context), t.b, C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) GELU_ERF(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_gelu_erf_inplace(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_gelu_erf_inplace(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
@ -1639,85 +1539,56 @@ func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
} else {
tt = C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t)
}
return &Tensor{b: t.b, t: tt}
return poolTensor(ctx.(*Context), t.b, tt)
}
func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
b: t.b,
t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
}
}
return &Tensor{
b: t.b,
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
return poolTensor(ctx.(*Context), t.b, C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t))
}
return poolTensor(ctx.(*Context), t.b, C.ggml_silu_inplace(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
b: t.b,
t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
}
}
return &Tensor{
b: t.b,
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
return poolTensor(ctx.(*Context), t.b, C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t))
}
return poolTensor(ctx.(*Context), t.b, C.ggml_relu_inplace(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),
}
return poolTensor(ctx.(*Context), t.b,
C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)))
}
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
}
return poolTensor(ctx.(*Context), t.b,
C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)))
}
func (t *Tensor) Conv1DDW(ctx ml.Context, weight ml.Tensor, s, p, d int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_conv_1d_dw(ctx.(*Context).ctx, weight.(*Tensor).t, t.t, C.int(s), C.int(p), C.int(d)),
}
return poolTensor(ctx.(*Context), t.b,
C.ggml_conv_1d_dw(ctx.(*Context).ctx, weight.(*Tensor).t, t.t, C.int(s), C.int(p), C.int(d)))
}
func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) ml.Tensor {
var tt ml.Tensor = &Tensor{
b: t.b,
t: C.ggml_conv_3d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int64_t(c), C.int(s0), C.int(s1), C.int(s2), C.int(p0), C.int(p1), C.int(p2), C.int(d0), C.int(d1), C.int(d2)),
}
tt = tt.Reshape(ctx, t.Dim(3)/c, t2.Dim(3)/c)
return tt
var tt ml.Tensor = poolTensor(ctx.(*Context), t.b,
C.ggml_conv_3d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int64_t(c), C.int(s0), C.int(s1), C.int(s2), C.int(p0), C.int(p1), C.int(p2), C.int(d0), C.int(d1), C.int(d2)))
return tt.Reshape(ctx, t.Dim(3)/c, t2.Dim(3)/c)
}
func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_ssm_conv(ctx.(*Context).ctx, t.t, kernel.(*Tensor).t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_ssm_conv(ctx.(*Context).ctx, t.t, kernel.(*Tensor).t))
}
func (t *Tensor) SSMScan(ctx ml.Context, x, dt, A, B, C, ids ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_ssm_scan(ctx.(*Context).ctx, t.t, x.(*Tensor).t, dt.(*Tensor).t, A.(*Tensor).t, B.(*Tensor).t, C.(*Tensor).t, ids.(*Tensor).t),
}
return poolTensor(ctx.(*Context), t.b,
C.ggml_ssm_scan(ctx.(*Context).ctx, t.t, x.(*Tensor).t, dt.(*Tensor).t, A.(*Tensor).t, B.(*Tensor).t, C.(*Tensor).t, ids.(*Tensor).t))
}
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
}
return poolTensor(ctx.(*Context), t.b,
C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)))
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor {
@ -1756,7 +1627,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
if vmla != nil {
var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
var cur ml.Tensor = poolTensor(ctx.(*Context), t.b, kqv)
cur = cur.Permute(ctx, 0, 2, 1, 3)
cur = vmla.Mulmat(ctx, cur)
cur = cur.Permute(ctx, 0, 2, 1, 3)
@ -1764,13 +1635,11 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
kqv = cur.(*Tensor).t
}
return &Tensor{b: t.b, t: kqv}
return poolTensor(ctx.(*Context), t.b, kqv)
} else {
kq := key.MulmatFullPrec(ctx, query)
kq = &Tensor{
b: t.b,
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
}
kq = poolTensor(ctx.(*Context), t.b,
C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0))
if sinks != nil {
C.ggml_soft_max_add_sinks(kq.(*Tensor).t, sinks.(*Tensor).t)
}
@ -1785,31 +1654,19 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
}
func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_dup(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_dup(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_argsort_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_argsort_top_k(ctx.(*Context).ctx, t.t, C.int(k)))
}
func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC))
}
func (t *Tensor) Mean(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mean(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_mean(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) Variance(ctx ml.Context) ml.Tensor {
@ -1824,87 +1681,53 @@ func (t *Tensor) Stddev(ctx ml.Context) ml.Tensor {
}
func (t *Tensor) Sqr(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sqr(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_sqr(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sqrt(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_sqrt(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_exp(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_exp(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_neg(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)))
}
func (t *Tensor) Softplus(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_softplus(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_softplus(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) CumSum(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cumsum(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_cumsum(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) Diag(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_diag(ctx.(*Context).ctx, t.t),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_diag(ctx.(*Context).ctx, t.t))
}
func (t *Tensor) Tri(ctx ml.Context, triType int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tri(ctx.(*Context).ctx, t.t, C.enum_ggml_tri_type(triType)),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_tri(ctx.(*Context).ctx, t.t, C.enum_ggml_tri_type(triType)))
}
func (t *Tensor) Fill(ctx ml.Context, value float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_fill_inplace(ctx.(*Context).ctx, t.t, C.float(value)),
}
return poolTensor(ctx.(*Context), t.b, C.ggml_fill_inplace(ctx.(*Context).ctx, t.t, C.float(value)))
}
func (t *Tensor) Repeat4D(ctx ml.Context, dim0, dim1, dim2, dim3 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_repeat_4d(ctx.(*Context).ctx, t.t, C.int64_t(dim0), C.int64_t(dim1), C.int64_t(dim2), C.int64_t(dim3)),
}
return poolTensor(ctx.(*Context), t.b,
C.ggml_repeat_4d(ctx.(*Context).ctx, t.t, C.int64_t(dim0), C.int64_t(dim1), C.int64_t(dim2), C.int64_t(dim3)))
}
func (t *Tensor) SolveTri(ctx ml.Context, b ml.Tensor, lower, left, unitDiag bool) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_solve_tri(ctx.(*Context).ctx, t.t, b.(*Tensor).t, C._Bool(lower), C._Bool(left), C._Bool(unitDiag)),
}
return poolTensor(ctx.(*Context), t.b,
C.ggml_solve_tri(ctx.(*Context).ctx, t.t, b.(*Tensor).t, C._Bool(lower), C._Bool(left), C._Bool(unitDiag)))
}
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
@ -1918,10 +1741,8 @@ func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.Sampli
panic("unsupported interpolate mode")
}
return &Tensor{
b: t.b,
t: C.ggml_interpolate(ctx.(*Context).ctx, t.t, C.int64_t(dims[0]), C.int64_t(dims[1]), C.int64_t(dims[2]), C.int64_t(dims[3]), mode),
}
return poolTensor(ctx.(*Context), t.b,
C.ggml_interpolate(ctx.(*Context).ctx, t.t, C.int64_t(dims[0]), C.int64_t(dims[1]), C.int64_t(dims[2]), C.int64_t(dims[3]), mode))
}
// Slice returns a view of the tensor sliced along dim from low to high in step steps.

View file

@ -56,7 +56,7 @@ func TestForward(t *testing.T) {
tt = typ.Forward(ctx, tt)
ctx.Forward(tt).Compute(tt)
if diff := cmp.Diff(want, tt.Floats()); diff != "" {
if diff := cmp.Diff(want, tt.Floats(nil)); diff != "" {
t.Error(diff)
}
})

View file

@ -3,10 +3,33 @@ package common
import (
"math"
"sort"
"sync"
"github.com/ollama/ollama/llm"
)
// tokenLogprobPair represents a token ID and its log probability.
type tokenLogprobPair struct {
tokenID int
logprob float32
}
// logprobBuffers holds reusable buffers for logprob calculations.
type logprobBuffers struct {
logProbs []float32
pairs []tokenLogprobPair
}
// logprobPool maintains a pool of reusable buffers to avoid allocations.
var logprobPool = sync.Pool{
New: func() any {
return &logprobBuffers{
logProbs: make([]float32, 0, 200000), // Pre-size for common vocab sizes
pairs: make([]tokenLogprobPair, 0, 200000),
}
},
}
// TokenDecoderFunc is a function that converts token IDs to text.
type TokenDecoderFunc func(tokenID int) string
@ -17,6 +40,16 @@ func CalculateLogprobs(logits []float32, selectedToken int, topK int, decoder To
return nil
}
// Get reusable buffers from pool
bufs := logprobPool.Get().(*logprobBuffers)
defer logprobPool.Put(bufs)
// Grow buffers if needed
if cap(bufs.logProbs) < len(logits) {
bufs.logProbs = make([]float32, len(logits))
}
logProbs := bufs.logProbs[:len(logits)]
// Step 1: Convert logits to log probabilities using numerically stable softmax
maxLogit := logits[0]
for _, logit := range logits[1:] {
@ -31,7 +64,6 @@ func CalculateLogprobs(logits []float32, selectedToken int, topK int, decoder To
}
logSumExp := float32(math.Log(sumExp))
logProbs := make([]float32, len(logits))
for i, logit := range logits {
logProbs[i] = (logit - maxLogit) - logSumExp
}
@ -49,12 +81,12 @@ func CalculateLogprobs(logits []float32, selectedToken int, topK int, decoder To
// Step 3: If topK requested, find the top K tokens
if topK > 0 {
type tokenLogprobPair struct {
tokenID int
logprob float32
// Reuse pairs buffer
if cap(bufs.pairs) < len(logProbs) {
bufs.pairs = make([]tokenLogprobPair, len(logProbs))
}
pairs := bufs.pairs[:len(logProbs)]
pairs := make([]tokenLogprobPair, len(logProbs))
for i, lp := range logProbs {
pairs[i] = tokenLogprobPair{tokenID: i, logprob: lp}
}

View file

@ -94,7 +94,7 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
for i, t := range entry.mm {
if t.Tensor != nil {
entry.data[i] = t.Tensor.Floats()
entry.data[i] = t.Tensor.Floats(nil)
}
}
} else {

View file

@ -385,6 +385,14 @@ type Server struct {
// multimodalHash generates hashes for comparing equality
// of non-text data
multimodalHash maphash.Hash
// logitsBuf is reused across decode steps to avoid per-request allocations
logitsBuf []float32
// Reusable buffers for computeBatch to avoid per-batch allocations
batchInputsBuf []int32
nextTokensBuf []*input.Input
iBatchesBuf []int
}
func (s *Server) allNil() bool {
@ -655,7 +663,10 @@ func (s *Server) computeBatch(activeBatch batchState) {
s.mu.Lock()
// Gather the actual input token values now that they're ready
batchInputs := make([]int32, len(activeBatch.batchInputs))
if cap(s.batchInputsBuf) < len(activeBatch.batchInputs) {
s.batchInputsBuf = make([]int32, len(activeBatch.batchInputs))
}
batchInputs := s.batchInputsBuf[:len(activeBatch.batchInputs)]
for i := range batchInputs {
batchInputs[i] = activeBatch.batchInputs[i].Token
}
@ -663,8 +674,19 @@ func (s *Server) computeBatch(activeBatch batchState) {
// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
// so that forwardBatch can build a batchInputs set which will eventually contain the actual
// decoded tokens.
nextBatchTokens := make([]*input.Input, len(s.seqs))
iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
if cap(s.nextTokensBuf) < len(s.seqs) {
s.nextTokensBuf = make([]*input.Input, len(s.seqs))
}
nextBatchTokens := s.nextTokensBuf[:len(s.seqs)]
// Clear buffer (important: set to nil, not reuse old pointers)
for i := range nextBatchTokens {
nextBatchTokens[i] = nil
}
if cap(s.iBatchesBuf) < len(s.seqs) {
s.iBatchesBuf = make([]int, len(s.seqs))
}
iBatches := s.iBatchesBuf[:len(s.seqs)] // Record the iBatch values before releasing the lock
for i, seq := range s.seqs {
iBatches[i] = -1
if seq == nil {
@ -720,7 +742,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
},
activeBatch.modelOutput)
outputs := activeBatch.modelOutput.Floats()
s.logitsBuf = activeBatch.modelOutput.Floats(s.logitsBuf)
outputs := s.logitsBuf
t := time.Now()
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)

View file

@ -23,6 +23,7 @@ type Sampler struct {
minP float32
temperature float32
grammar *GrammarSampler
tokenBuf []token
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
@ -30,7 +31,10 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
return -1, errors.New("sample: no logits provided to sample")
}
tokens := make([]token, len(logits))
if cap(s.tokenBuf) < len(logits) {
s.tokenBuf = make([]token, len(logits))
}
tokens := s.tokenBuf[:len(logits)]
for i := range logits {
tokens[i].id = int32(i)
tokens[i].value = logits[i]
@ -165,7 +169,8 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
}
type GrammarSampler struct {
grammar *llama.Grammar
grammar *llama.Grammar
tokenDataBuf []llama.TokenData
}
func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
@ -185,7 +190,11 @@ func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSamp
}
func (g *GrammarSampler) Apply(tokens []token) {
tds := make([]llama.TokenData, len(tokens))
if cap(g.tokenDataBuf) < len(tokens) {
g.tokenDataBuf = make([]llama.TokenData, len(tokens))
}
tds := g.tokenDataBuf[:len(tokens)]
for i, token := range tokens {
tds[i].ID = token.id
tds[i].Logit = token.value

View file

@ -1,30 +1,10 @@
package sample
import (
"container/heap"
"math"
"slices"
)
// tokenHeap implements heap.Interface and holds tokens as a min-heap to
// track the k largest elements: the root is always the smallest retained
// token, so it can be cheaply replaced whenever a larger one is seen.
type tokenHeap []token
// Len reports how many tokens the heap currently holds.
func (h tokenHeap) Len() int {
	return len(h)
}
// Less orders tokens by ascending value, which makes the heap a
// min-heap (smallest value at the root).
func (h tokenHeap) Less(i, j int) bool {
	return h[j].value > h[i].value
}
// Swap exchanges the tokens at positions i and j.
func (h tokenHeap) Swap(i, j int) {
	tmp := h[i]
	h[i] = h[j]
	h[j] = tmp
}
// Push appends x to the heap's backing slice; container/heap calls it
// when inserting a new element. x must be a token.
func (h *tokenHeap) Push(x any) {
	t := x.(token)
	*h = append(*h, t)
}
// Pop removes and returns the last element of the backing slice;
// container/heap swaps the root there before calling it, so this is
// the minimum-value token.
func (h *tokenHeap) Pop() any {
	s := *h
	last := len(s) - 1
	popped := s[last]
	*h = s[:last]
	return popped
}
// temperature applies scaling to the logits
func temperature(ts []token, temp float32) {
// Ensure temperature clipping near 0 to avoid numerical instability
@ -59,40 +39,20 @@ func softmax(ts []token) {
// topK limits the number of tokens considered to the k highest logits
func topK(ts []token, k int) []token {
if k >= len(ts) || k <= 0 {
slices.SortFunc(ts, func(a, b token) int {
switch {
case a.value < b.value:
return 1
case a.value > b.value:
return -1
default:
return 0
}
})
slices.SortFunc(ts, func(a, b token) int {
switch {
case a.value < b.value:
return 1
case a.value > b.value:
return -1
default:
return 0
}
})
if k <= 0 || k >= len(ts) {
return ts
}
// Initialize min-heap with first k elements
h := make(tokenHeap, k)
copy(h, ts[:k])
heap.Init(&h)
// Process remaining elements
for i := k; i < len(ts); i++ {
if ts[i].value > h[0].value {
heap.Pop(&h)
heap.Push(&h, ts[i])
}
}
// Convert heap to sorted slice in descending order
result := make([]token, len(h))
for i := k - 1; i >= 0; i-- {
result[i] = heap.Pop(&h).(token)
}
return result
return ts[:k]
}
// topP limits tokens to those with cumulative probability p