diff --git a/MLX_VERSION b/MLX_VERSION index b043aa648..5aff472dd 100644 --- a/MLX_VERSION +++ b/MLX_VERSION @@ -1 +1 @@ -v0.5.0 +v0.4.1 diff --git a/x/imagegen/mlx/generate_wrappers.go b/x/imagegen/mlx/generate_wrappers.go index 8aa5bd0c8..a55def02b 100644 --- a/x/imagegen/mlx/generate_wrappers.go +++ b/x/imagegen/mlx/generate_wrappers.go @@ -16,10 +16,10 @@ import ( ) type Function struct { - Name string - ReturnType string - Params string - ParamNames []string + Name string + ReturnType string + Params string + ParamNames []string NeedsARM64Guard bool } @@ -29,11 +29,6 @@ func findHeaders(directory string) ([]string, error) { if err != nil { return err } - // Private headers contain C++ implementation helpers and are not part of - // the C API surface; parsing them can produce invalid wrapper signatures. - if d.IsDir() && d.Name() == "private" { - return fs.SkipDir - } if !d.IsDir() && strings.HasSuffix(path, ".h") { headers = append(headers, path) } @@ -199,10 +194,10 @@ func parseFunctions(content string) []Function { needsGuard := needsARM64Guard(funcName, returnType, params) functions = append(functions, Function{ - Name: funcName, - ReturnType: returnType, - Params: params, - ParamNames: paramNames, + Name: funcName, + ReturnType: returnType, + Params: params, + ParamNames: paramNames, NeedsARM64Guard: needsGuard, }) } diff --git a/x/imagegen/mlx/mlx.c b/x/imagegen/mlx/mlx.c index 770b60922..564076f30 100644 --- a/x/imagegen/mlx/mlx.c +++ b/x/imagegen/mlx/mlx.c @@ -20,8 +20,6 @@ mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL; mlx_array (*mlx_array_new_double_ptr)(double val) = NULL; mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL; mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL; -mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL; -mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL; int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL; int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL; int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL; @@ -51,7 +49,7 @@ int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL; int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL; -int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL; #if defined(__aarch64__) || defined(_M_ARM64) int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL; #endif @@ -69,7 +67,7 @@ const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL; const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL; const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL; const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL; -const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL; +const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL; #if defined(__aarch64__) || defined(_M_ARM64) const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL; #endif @@ -125,7 +123,6 @@ int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL; int (*mlx_disable_compile_ptr)(void) = NULL; int (*mlx_enable_compile_ptr)(void) = NULL; int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL; -int (*mlx_cuda_is_available_ptr)(bool* res) = NULL; mlx_device (*mlx_device_new_ptr)(void) = NULL; mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL; int (*mlx_device_free_ptr)(mlx_device dev) = NULL; @@ -136,16 +133,6 @@ int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL; int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL; int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL; int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL; -int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL; -int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL; -mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL; -int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL; -int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL; -int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL; -int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL; -int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL; -int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL; -int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL; int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL; int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; @@ -276,6 +263,7 @@ int (*mlx_reset_peak_memory_ptr)(void) = NULL; int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL; int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL; int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL; +mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL; int (*mlx_metal_is_available_ptr)(bool* res) = NULL; int (*mlx_metal_start_capture_ptr)(const char* path) = NULL; int (*mlx_metal_stop_capture_ptr)(void) = NULL; @@ -670,16 +658,6 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n"); return -1; } - mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed"); - if (mlx_array_new_data_managed_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n"); - return -1; - } - mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload"); - if (mlx_array_new_data_managed_payload_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n"); - return -1; - } mlx_array_set_ptr = dlsym(handle, "mlx_array_set"); if (mlx_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n"); @@ -1163,11 +1141,6 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n"); return -1; } - mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available"); - if (mlx_cuda_is_available_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n"); - return -1; - } mlx_device_new_ptr = dlsym(handle, "mlx_device_new"); if (mlx_device_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n"); @@ -1218,56 +1191,6 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n"); return -1; } - mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available"); - if (mlx_device_is_available_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n"); - return -1; - } - mlx_device_count_ptr = dlsym(handle, "mlx_device_count"); - if (mlx_device_count_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n"); - return -1; - } - mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new"); - if (mlx_device_info_new_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n"); - return -1; - } - mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get"); - if (mlx_device_info_get_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n"); - return -1; - } - mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free"); - if (mlx_device_info_free_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n"); - return -1; - } - mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key"); - if (mlx_device_info_has_key_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n"); - return -1; - } - mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string"); - if (mlx_device_info_is_string_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n"); - return -1; - } - mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string"); - if (mlx_device_info_get_string_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n"); - return -1; - } - mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size"); - if (mlx_device_info_get_size_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n"); - return -1; - } - mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys"); - if (mlx_device_info_get_keys_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n"); - return -1; - } mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather"); if (mlx_distributed_all_gather_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n"); @@ -1918,6 +1841,11 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n"); return -1; } + mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info"); + if (mlx_metal_device_info_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n"); + return -1; + } mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available"); if (mlx_metal_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n"); @@ -3600,14 +3528,6 @@ mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dt return mlx_array_new_data_ptr(data, shape, dim, dtype); } -mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) { - return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor); -} - -mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) { - return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor); -} - int mlx_array_set(mlx_array* arr, const mlx_array src) { return mlx_array_set_ptr(arr, src); } @@ -3724,7 +3644,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr) { return mlx_array_item_float64_ptr(res, arr); } -int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) { +int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) { return mlx_array_item_complex64_ptr(res, arr); } @@ -3784,7 +3704,7 @@ const double* mlx_array_data_float64(const mlx_array arr) { return mlx_array_data_float64_ptr(arr); } -const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) { +const float _Complex* mlx_array_data_complex64(const mlx_array arr) { return mlx_array_data_complex64_ptr(arr); } @@ -3996,10 +3916,6 @@ int mlx_set_compile_mode(mlx_compile_mode mode) { return mlx_set_compile_mode_ptr(mode); } -int mlx_cuda_is_available(bool* res) { - return mlx_cuda_is_available_ptr(res); -} - mlx_device mlx_device_new(void) { return mlx_device_new_ptr(); } @@ -4040,46 +3956,6 @@ int mlx_set_default_device(mlx_device dev) { return mlx_set_default_device_ptr(dev); } -int mlx_device_is_available(bool* avail, mlx_device dev) { - return mlx_device_is_available_ptr(avail, dev); -} - -int mlx_device_count(int* count, mlx_device_type type) { - return mlx_device_count_ptr(count, type); -} - -mlx_device_info mlx_device_info_new(void) { - return mlx_device_info_new_ptr(); -} - -int mlx_device_info_get(mlx_device_info* info, mlx_device dev) { - return mlx_device_info_get_ptr(info, dev); -} - -int mlx_device_info_free(mlx_device_info info) { - return mlx_device_info_free_ptr(info); -} - -int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) { - return mlx_device_info_has_key_ptr(exists, info, key); -} - -int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) { - return mlx_device_info_is_string_ptr(is_string, info, key); -} - -int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) { - return mlx_device_info_get_string_ptr(value, info, key); -} - -int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) { - return mlx_device_info_get_size_ptr(value, info, key); -} - -int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) { - return mlx_device_info_get_keys_ptr(keys, info); -} - int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) { return mlx_distributed_all_gather_ptr(res, x, group, S); } @@ -4600,6 +4476,10 @@ int mlx_set_wired_limit(size_t* res, size_t limit) { return mlx_set_wired_limit_ptr(res, limit); } +mlx_metal_device_info_t mlx_metal_device_info(void) { + return mlx_metal_device_info_ptr(); +} + int mlx_metal_is_available(bool* res) { return mlx_metal_is_available_ptr(res); } diff --git a/x/imagegen/mlx/mlx.h b/x/imagegen/mlx/mlx.h index 34829d732..d4ed1a905 100644 --- a/x/imagegen/mlx/mlx.h +++ b/x/imagegen/mlx/mlx.h @@ -26,8 +26,6 @@ #undef mlx_array_new_double #undef mlx_array_new_complex #undef mlx_array_new_data -#undef mlx_array_new_data_managed -#undef mlx_array_new_data_managed_payload #undef mlx_array_set #undef mlx_array_set_bool #undef mlx_array_set_int @@ -123,7 +121,6 @@ #undef mlx_disable_compile #undef mlx_enable_compile #undef mlx_set_compile_mode -#undef mlx_cuda_is_available #undef mlx_device_new #undef mlx_device_new_type #undef mlx_device_free @@ -134,16 +131,6 @@ #undef mlx_device_get_type #undef mlx_get_default_device #undef mlx_set_default_device -#undef mlx_device_is_available -#undef mlx_device_count -#undef mlx_device_info_new -#undef mlx_device_info_get -#undef mlx_device_info_free -#undef mlx_device_info_has_key -#undef mlx_device_info_is_string -#undef mlx_device_info_get_string -#undef mlx_device_info_get_size -#undef mlx_device_info_get_keys #undef mlx_distributed_all_gather #undef mlx_distributed_all_max #undef mlx_distributed_all_min @@ -274,6 +261,7 @@ #undef mlx_set_cache_limit #undef mlx_set_memory_limit #undef mlx_set_wired_limit +#undef mlx_metal_device_info #undef mlx_metal_is_available #undef mlx_metal_start_capture #undef mlx_metal_stop_capture @@ -614,8 +602,6 @@ extern mlx_array (*mlx_array_new_float64_ptr)(double val); extern mlx_array (*mlx_array_new_double_ptr)(double val); extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val); extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype); -extern mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)); -extern mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)); extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src); extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val); extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val); @@ -645,7 +631,7 @@ extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr); extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr); extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr); extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr); -extern int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr); +extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr); #if defined(__aarch64__) || defined(_M_ARM64) extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr); #endif @@ -663,7 +649,7 @@ extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr); extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr); extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr); extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr); -extern const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr); +extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr); #if defined(__aarch64__) || defined(_M_ARM64) extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr); #endif @@ -719,7 +705,6 @@ extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id); extern int (*mlx_disable_compile_ptr)(void); extern int (*mlx_enable_compile_ptr)(void); extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode); -extern int (*mlx_cuda_is_available_ptr)(bool* res); extern mlx_device (*mlx_device_new_ptr)(void); extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index); extern int (*mlx_device_free_ptr)(mlx_device dev); @@ -730,16 +715,6 @@ extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev); extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev); extern int (*mlx_get_default_device_ptr)(mlx_device* dev); extern int (*mlx_set_default_device_ptr)(mlx_device dev); -extern int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev); -extern int (*mlx_device_count_ptr)(int* count, mlx_device_type type); -extern mlx_device_info (*mlx_device_info_new_ptr)(void); -extern int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev); -extern int (*mlx_device_info_free_ptr)(mlx_device_info info); -extern int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key); -extern int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key); -extern int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key); -extern int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key); -extern int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info); extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S); extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); @@ -870,6 +845,7 @@ extern int (*mlx_reset_peak_memory_ptr)(void); extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit); extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit); extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit); +extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void); extern int (*mlx_metal_is_available_ptr)(bool* res); extern int (*mlx_metal_start_capture_ptr)(const char* path); extern int (*mlx_metal_stop_capture_ptr)(void); @@ -1226,10 +1202,6 @@ mlx_array mlx_array_new_complex(float real_val, float imag_val); mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype); -mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)); - -mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)); - int mlx_array_set(mlx_array* arr, const mlx_array src); int mlx_array_set_bool(mlx_array* arr, bool val); @@ -1288,7 +1260,7 @@ int mlx_array_item_float32(float* res, const mlx_array arr); int mlx_array_item_float64(double* res, const mlx_array arr); -int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr); +int mlx_array_item_complex64(float _Complex* res, const mlx_array arr); #if defined(__aarch64__) || defined(_M_ARM64) int mlx_array_item_float16(float16_t* res, const mlx_array arr); @@ -1320,7 +1292,7 @@ const float* mlx_array_data_float32(const mlx_array arr); const double* mlx_array_data_float64(const mlx_array arr); -const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr); +const float _Complex* mlx_array_data_complex64(const mlx_array arr); #if defined(__aarch64__) || defined(_M_ARM64) const float16_t* mlx_array_data_float16(const mlx_array arr); @@ -1428,8 +1400,6 @@ int mlx_enable_compile(void); int mlx_set_compile_mode(mlx_compile_mode mode); -int mlx_cuda_is_available(bool* res); - mlx_device mlx_device_new(void); mlx_device mlx_device_new_type(mlx_device_type type, int index); @@ -1450,26 +1420,6 @@ int mlx_get_default_device(mlx_device* dev); int mlx_set_default_device(mlx_device dev); -int mlx_device_is_available(bool* avail, mlx_device dev); - -int mlx_device_count(int* count, mlx_device_type type); - -mlx_device_info mlx_device_info_new(void); - -int mlx_device_info_get(mlx_device_info* info, mlx_device dev); - -int mlx_device_info_free(mlx_device_info info); - -int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key); - -int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key); - -int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key); - -int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key); - -int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info); - int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S); int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); @@ -1730,6 +1680,8 @@ int mlx_set_memory_limit(size_t* res, size_t limit); int mlx_set_wired_limit(size_t* res, size_t limit); +mlx_metal_device_info_t mlx_metal_device_info(void); + int mlx_metal_is_available(bool* res); int mlx_metal_start_capture(const char* path); diff --git a/x/mlxrunner/mlx/CMakeLists.txt b/x/mlxrunner/mlx/CMakeLists.txt index 1ca13bdaf..c41ce46f7 100644 --- a/x/mlxrunner/mlx/CMakeLists.txt +++ b/x/mlxrunner/mlx/CMakeLists.txt @@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path") include(FetchContent) -set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "") +set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "") FetchContent_Declare( mlx-c diff --git a/x/mlxrunner/mlx/generated.c b/x/mlxrunner/mlx/generated.c index 29d1330af..af99b631e 100644 --- a/x/mlxrunner/mlx/generated.c +++ b/x/mlxrunner/mlx/generated.c @@ -22,19 +22,6 @@ mlx_array (*mlx_array_new_data_)( const int* shape, int dim, mlx_dtype dtype) = NULL; -mlx_array (*mlx_array_new_data_managed_)( - void* data, - const int* shape, - int dim, - mlx_dtype dtype, - void (*dtor)(void*)) = NULL; -mlx_array (*mlx_array_new_data_managed_payload_)( - void* data, - const int* shape, - int dim, - mlx_dtype dtype, - void* payload, - void (*dtor)(void*)) = NULL; int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL; int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL; int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL; @@ -69,7 +56,7 @@ int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL; int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL; -int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL; int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL; const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL; @@ -83,7 +70,7 @@ const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL; const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL; const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL; const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL; -const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL; +const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL; const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL; const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL; int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL; @@ -107,11 +94,10 @@ int (*mlx_closure_apply_)( mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL; mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL; int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL; -mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)( - int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_map_string_to_array)) = NULL; +mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array)) = NULL; mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -150,12 +136,11 @@ int (*mlx_closure_value_and_grad_apply_)( const mlx_vector_array input) = NULL; mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL; int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL; -mlx_closure_custom (*mlx_closure_custom_new_func_)( - int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_vector_array, - const mlx_vector_array)) = NULL; +mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array)) = NULL; mlx_closure_custom (*mlx_closure_custom_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -176,13 +161,12 @@ int (*mlx_closure_custom_apply_)( const mlx_vector_array input_2) = NULL; mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL; int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL; -mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)( - int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_vector_array, - const int*, - size_t _num)) = NULL; +mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num)) = NULL; mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -205,13 +189,12 @@ int (*mlx_closure_custom_jvp_apply_)( size_t input_2_num) = NULL; mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL; int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL; -mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)( - int (*fun)( - mlx_vector_array*, - mlx_vector_int*, - const mlx_vector_array, - const int*, - size_t _num)) = NULL; +mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num)) = NULL; mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -245,7 +228,6 @@ int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL; int (*mlx_disable_compile_)(void) = NULL; int (*mlx_enable_compile_)(void) = NULL; int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL; -int (*mlx_cuda_is_available_)(bool* res) = NULL; mlx_device (*mlx_device_new_)(void) = NULL; mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL; int (*mlx_device_free_)(mlx_device dev) = NULL; @@ -256,28 +238,11 @@ int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL; int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL; int (*mlx_get_default_device_)(mlx_device* dev) = NULL; int (*mlx_set_default_device_)(mlx_device dev) = NULL; -int (*mlx_device_is_available_)(bool* avail, mlx_device dev) = NULL; -int (*mlx_device_count_)(int* count, mlx_device_type type) = NULL; -mlx_device_info (*mlx_device_info_new_)(void) = NULL; -int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev) = NULL; -int (*mlx_device_info_free_)(mlx_device_info info) = NULL; -int (*mlx_device_info_has_key_)( - bool* exists, - mlx_device_info info, - const char* key) = NULL; -int (*mlx_device_info_is_string_)( - bool* is_string, - mlx_device_info info, - const char* key) = NULL; -int (*mlx_device_info_get_string_)( - const char** value, - mlx_device_info info, - const char* key) = NULL; -int (*mlx_device_info_get_size_)( - size_t* value, - mlx_device_info info, - const char* key) = NULL; -int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info) = NULL; +int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL; +int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL; +mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL; +bool (*mlx_distributed_is_available_)(void) = NULL; +mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL; int (*mlx_distributed_all_gather_)( mlx_array* res, const mlx_array x, @@ -323,11 +288,6 @@ int (*mlx_distributed_sum_scatter_)( const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s) = NULL; -int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL; -int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL; -mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL; -bool (*mlx_distributed_is_available_)(void) = NULL; -mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL; void (*mlx_set_error_handler_)( mlx_error_handler_func handler, void* data, @@ -490,16 +450,6 @@ int (*mlx_fast_rope_)( int offset, const mlx_array freqs /* may be null */, const mlx_stream s) = NULL; -int (*mlx_fast_rope_dynamic_)( - mlx_array* res, - const mlx_array x, - int dims, - bool traditional, - mlx_optional_float base, - float scale, - const mlx_array offset, - const mlx_array freqs /* may be null */, - const mlx_stream s) = NULL; int (*mlx_fast_scaled_dot_product_attention_)( mlx_array* res, const mlx_array queries, @@ -610,6 +560,14 @@ int (*mlx_fft_rfftn_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; +mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL; +int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL; +int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL; +int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL; +mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL; +int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL; +int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL; +int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL; int (*mlx_load_reader_)( mlx_array* res, mlx_io_reader in_stream, @@ -635,14 +593,6 @@ int (*mlx_save_safetensors_)( const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) = NULL; -mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL; -int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL; -int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL; -int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL; -mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL; -int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL; -int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL; -int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL; int (*mlx_linalg_cholesky_)( mlx_array* res, const mlx_array a, @@ -783,6 +733,7 @@ int (*mlx_reset_peak_memory_)(void) = NULL; int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL; int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL; int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL; +mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL; int (*mlx_metal_is_available_)(bool* res) = NULL; int (*mlx_metal_start_capture_)(const char* path) = NULL; int (*mlx_metal_stop_capture_)(void) = NULL; @@ -1211,14 +1162,6 @@ int (*mlx_gather_)( const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) = NULL; -int (*mlx_gather_single_)( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - int axis, - const int* slice_sizes, - size_t slice_sizes_num, - const mlx_stream s) = NULL; int (*mlx_gather_mm_)( mlx_array* res, const mlx_array a, @@ -1540,15 +1483,6 @@ int (*mlx_put_along_axis_)( const mlx_array values, int axis, const mlx_stream s) = NULL; -int (*mlx_qqmm_)( - mlx_array* res, - const mlx_array x, - const mlx_array w, - const mlx_array w_scales /* may be null */, - mlx_optional_int group_size, - mlx_optional_int bits, - const char* mode, - const mlx_stream s) = NULL; int (*mlx_quantize_)( mlx_vector_array* res, const mlx_array w, @@ -1632,13 +1566,6 @@ int (*mlx_scatter_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; -int (*mlx_scatter_single_)( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s) = NULL; int (*mlx_scatter_add_)( mlx_array* res, const mlx_array a, @@ -1647,13 +1574,6 @@ int (*mlx_scatter_add_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; -int (*mlx_scatter_add_single_)( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s) = NULL; int (*mlx_scatter_add_axis_)( mlx_array* res, const mlx_array a, @@ -1669,13 +1589,6 @@ int (*mlx_scatter_max_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; -int (*mlx_scatter_max_single_)( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s) = NULL; int (*mlx_scatter_min_)( mlx_array* res, const mlx_array a, @@ -1684,13 +1597,6 @@ int (*mlx_scatter_min_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; -int (*mlx_scatter_min_single_)( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s) = NULL; int (*mlx_scatter_prod_)( mlx_array* res, const mlx_array a, @@ -1699,13 +1605,6 @@ int (*mlx_scatter_prod_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; -int (*mlx_scatter_prod_single_)( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s) = NULL; int (*mlx_segmented_mm_)( mlx_array* res, const mlx_array a, @@ -2129,6 +2028,22 @@ mlx_string (*mlx_string_new_data_)(const char* str) = NULL; int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL; const char * (*mlx_string_data_)(mlx_string str) = NULL; int (*mlx_string_free_)(mlx_string str) = NULL; +int (*mlx_detail_vmap_replace_)( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array s_inputs, + const mlx_vector_array s_outputs, + const int* in_axes, + size_t in_axes_num, + const int* out_axes, + size_t out_axes_num) = NULL; +int (*mlx_detail_vmap_trace_)( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array inputs, + const int* in_axes, + size_t in_axes_num) = NULL; int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL; int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL; int (*mlx_custom_function_)( @@ -2159,22 +2074,6 @@ int (*mlx_vjp_)( const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents) = NULL; -int (*mlx_detail_vmap_replace_)( - mlx_vector_array* res, - const mlx_vector_array inputs, - const mlx_vector_array s_inputs, - const mlx_vector_array s_outputs, - const int* in_axes, - size_t in_axes_num, - const int* out_axes, - size_t out_axes_num) = NULL; -int (*mlx_detail_vmap_trace_)( - mlx_vector_array* res_0, - mlx_vector_array* res_1, - const mlx_closure fun, - const mlx_vector_array inputs, - const int* in_axes, - size_t in_axes_num) = NULL; mlx_vector_array (*mlx_vector_array_new_)(void) = NULL; int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL; int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL; @@ -2267,8 +2166,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_array_new_double); CHECK_LOAD(handle, mlx_array_new_complex); CHECK_LOAD(handle, mlx_array_new_data); - CHECK_LOAD(handle, mlx_array_new_data_managed); - CHECK_LOAD(handle, mlx_array_new_data_managed_payload); CHECK_LOAD(handle, mlx_array_set); CHECK_LOAD(handle, mlx_array_set_bool); CHECK_LOAD(handle, mlx_array_set_int); @@ -2364,7 +2261,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_disable_compile); CHECK_LOAD(handle, mlx_enable_compile); CHECK_LOAD(handle, mlx_set_compile_mode); - CHECK_LOAD(handle, mlx_cuda_is_available); CHECK_LOAD(handle, mlx_device_new); CHECK_LOAD(handle, mlx_device_new_type); CHECK_LOAD(handle, mlx_device_free); @@ -2375,16 +2271,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_device_get_type); CHECK_LOAD(handle, mlx_get_default_device); CHECK_LOAD(handle, mlx_set_default_device); - CHECK_LOAD(handle, mlx_device_is_available); - CHECK_LOAD(handle, mlx_device_count); - CHECK_LOAD(handle, mlx_device_info_new); - CHECK_LOAD(handle, mlx_device_info_get); - CHECK_LOAD(handle, mlx_device_info_free); - CHECK_LOAD(handle, mlx_device_info_has_key); - CHECK_LOAD(handle, mlx_device_info_is_string); - CHECK_LOAD(handle, mlx_device_info_get_string); - CHECK_LOAD(handle, mlx_device_info_get_size); - CHECK_LOAD(handle, mlx_device_info_get_keys); + CHECK_LOAD(handle, mlx_distributed_group_rank); + CHECK_LOAD(handle, mlx_distributed_group_size); + CHECK_LOAD(handle, mlx_distributed_group_split); + CHECK_LOAD(handle, mlx_distributed_is_available); + CHECK_LOAD(handle, mlx_distributed_init); CHECK_LOAD(handle, mlx_distributed_all_gather); CHECK_LOAD(handle, mlx_distributed_all_max); CHECK_LOAD(handle, mlx_distributed_all_min); @@ -2393,11 +2284,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_distributed_recv_like); CHECK_LOAD(handle, mlx_distributed_send); CHECK_LOAD(handle, mlx_distributed_sum_scatter); - CHECK_LOAD(handle, mlx_distributed_group_rank); - CHECK_LOAD(handle, mlx_distributed_group_size); - CHECK_LOAD(handle, mlx_distributed_group_split); - CHECK_LOAD(handle, mlx_distributed_is_available); - CHECK_LOAD(handle, mlx_distributed_init); CHECK_LOAD(handle, mlx_set_error_handler); CHECK_LOAD(handle, _mlx_error); CHECK_LOAD(handle, mlx_export_function); @@ -2439,7 +2325,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_fast_metal_kernel_apply); CHECK_LOAD(handle, mlx_fast_rms_norm); CHECK_LOAD(handle, mlx_fast_rope); - CHECK_LOAD(handle, mlx_fast_rope_dynamic); CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention); CHECK_LOAD(handle, mlx_fft_fft); CHECK_LOAD(handle, mlx_fft_fft2); @@ -2455,14 +2340,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_fft_rfft); CHECK_LOAD(handle, mlx_fft_rfft2); CHECK_LOAD(handle, mlx_fft_rfftn); - CHECK_LOAD(handle, mlx_load_reader); - CHECK_LOAD(handle, mlx_load); - CHECK_LOAD(handle, mlx_load_safetensors_reader); - CHECK_LOAD(handle, mlx_load_safetensors); - CHECK_LOAD(handle, mlx_save_writer); - CHECK_LOAD(handle, mlx_save); - CHECK_LOAD(handle, mlx_save_safetensors_writer); - CHECK_LOAD(handle, mlx_save_safetensors); CHECK_LOAD(handle, mlx_io_reader_new); CHECK_LOAD(handle, mlx_io_reader_descriptor); CHECK_LOAD(handle, mlx_io_reader_tostring); @@ -2471,6 +2348,14 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_io_writer_descriptor); CHECK_LOAD(handle, mlx_io_writer_tostring); CHECK_LOAD(handle, mlx_io_writer_free); + CHECK_LOAD(handle, mlx_load_reader); + CHECK_LOAD(handle, mlx_load); + CHECK_LOAD(handle, mlx_load_safetensors_reader); + CHECK_LOAD(handle, mlx_load_safetensors); + CHECK_LOAD(handle, mlx_save_writer); + CHECK_LOAD(handle, mlx_save); + CHECK_LOAD(handle, mlx_save_safetensors_writer); + CHECK_LOAD(handle, mlx_save_safetensors); CHECK_LOAD(handle, mlx_linalg_cholesky); CHECK_LOAD(handle, mlx_linalg_cholesky_inv); CHECK_LOAD(handle, mlx_linalg_cross); @@ -2515,6 +2400,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_set_cache_limit); CHECK_LOAD(handle, mlx_set_memory_limit); CHECK_LOAD(handle, mlx_set_wired_limit); + CHECK_LOAD(handle, mlx_metal_device_info); CHECK_LOAD(handle, mlx_metal_is_available); CHECK_LOAD(handle, mlx_metal_start_capture); CHECK_LOAD(handle, mlx_metal_stop_capture); @@ -2600,7 +2486,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_full); CHECK_LOAD(handle, mlx_full_like); CHECK_LOAD(handle, mlx_gather); - CHECK_LOAD(handle, mlx_gather_single); CHECK_LOAD(handle, mlx_gather_mm); CHECK_LOAD(handle, mlx_gather_qmm); CHECK_LOAD(handle, mlx_greater); @@ -2665,7 +2550,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_prod_axis); CHECK_LOAD(handle, mlx_prod); CHECK_LOAD(handle, mlx_put_along_axis); - CHECK_LOAD(handle, mlx_qqmm); CHECK_LOAD(handle, mlx_quantize); CHECK_LOAD(handle, mlx_quantized_matmul); CHECK_LOAD(handle, mlx_radians); @@ -2682,16 +2566,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_round); CHECK_LOAD(handle, mlx_rsqrt); CHECK_LOAD(handle, mlx_scatter); - CHECK_LOAD(handle, mlx_scatter_single); CHECK_LOAD(handle, mlx_scatter_add); - CHECK_LOAD(handle, mlx_scatter_add_single); CHECK_LOAD(handle, mlx_scatter_add_axis); CHECK_LOAD(handle, mlx_scatter_max); - CHECK_LOAD(handle, mlx_scatter_max_single); CHECK_LOAD(handle, mlx_scatter_min); - CHECK_LOAD(handle, mlx_scatter_min_single); CHECK_LOAD(handle, mlx_scatter_prod); - CHECK_LOAD(handle, mlx_scatter_prod_single); CHECK_LOAD(handle, mlx_segmented_mm); CHECK_LOAD(handle, mlx_sigmoid); CHECK_LOAD(handle, mlx_sign); @@ -2786,6 +2665,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_string_set); CHECK_LOAD(handle, mlx_string_data); CHECK_LOAD(handle, mlx_string_free); + CHECK_LOAD(handle, mlx_detail_vmap_replace); + CHECK_LOAD(handle, mlx_detail_vmap_trace); CHECK_LOAD(handle, mlx_async_eval); CHECK_LOAD(handle, mlx_checkpoint); CHECK_LOAD(handle, mlx_custom_function); @@ -2794,8 +2675,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_jvp); CHECK_LOAD(handle, mlx_value_and_grad); CHECK_LOAD(handle, mlx_vjp); - CHECK_LOAD(handle, mlx_detail_vmap_replace); - CHECK_LOAD(handle, mlx_detail_vmap_trace); CHECK_LOAD(handle, mlx_vector_array_new); CHECK_LOAD(handle, mlx_vector_array_set); CHECK_LOAD(handle, mlx_vector_array_free); diff --git a/x/mlxrunner/mlx/generated.h b/x/mlxrunner/mlx/generated.h index e8dfa7b90..c88946d9f 100644 --- a/x/mlxrunner/mlx/generated.h +++ b/x/mlxrunner/mlx/generated.h @@ -17,8 +17,6 @@ #define mlx_array_new_double mlx_array_new_double_mlx_gen_orig_ #define mlx_array_new_complex mlx_array_new_complex_mlx_gen_orig_ #define mlx_array_new_data mlx_array_new_data_mlx_gen_orig_ -#define mlx_array_new_data_managed mlx_array_new_data_managed_mlx_gen_orig_ -#define mlx_array_new_data_managed_payload mlx_array_new_data_managed_payload_mlx_gen_orig_ #define mlx_array_set mlx_array_set_mlx_gen_orig_ #define mlx_array_set_bool mlx_array_set_bool_mlx_gen_orig_ #define mlx_array_set_int mlx_array_set_int_mlx_gen_orig_ @@ -114,7 +112,6 @@ #define mlx_disable_compile mlx_disable_compile_mlx_gen_orig_ #define mlx_enable_compile mlx_enable_compile_mlx_gen_orig_ #define mlx_set_compile_mode mlx_set_compile_mode_mlx_gen_orig_ -#define mlx_cuda_is_available mlx_cuda_is_available_mlx_gen_orig_ #define mlx_device_new mlx_device_new_mlx_gen_orig_ #define mlx_device_new_type mlx_device_new_type_mlx_gen_orig_ #define mlx_device_free mlx_device_free_mlx_gen_orig_ @@ -125,16 +122,11 @@ #define mlx_device_get_type mlx_device_get_type_mlx_gen_orig_ #define mlx_get_default_device mlx_get_default_device_mlx_gen_orig_ #define mlx_set_default_device mlx_set_default_device_mlx_gen_orig_ -#define mlx_device_is_available mlx_device_is_available_mlx_gen_orig_ -#define mlx_device_count mlx_device_count_mlx_gen_orig_ -#define mlx_device_info_new mlx_device_info_new_mlx_gen_orig_ -#define mlx_device_info_get mlx_device_info_get_mlx_gen_orig_ -#define mlx_device_info_free mlx_device_info_free_mlx_gen_orig_ -#define mlx_device_info_has_key mlx_device_info_has_key_mlx_gen_orig_ -#define mlx_device_info_is_string mlx_device_info_is_string_mlx_gen_orig_ -#define mlx_device_info_get_string mlx_device_info_get_string_mlx_gen_orig_ -#define mlx_device_info_get_size mlx_device_info_get_size_mlx_gen_orig_ -#define mlx_device_info_get_keys mlx_device_info_get_keys_mlx_gen_orig_ +#define mlx_distributed_group_rank mlx_distributed_group_rank_mlx_gen_orig_ +#define mlx_distributed_group_size mlx_distributed_group_size_mlx_gen_orig_ +#define mlx_distributed_group_split mlx_distributed_group_split_mlx_gen_orig_ +#define mlx_distributed_is_available mlx_distributed_is_available_mlx_gen_orig_ +#define mlx_distributed_init mlx_distributed_init_mlx_gen_orig_ #define mlx_distributed_all_gather mlx_distributed_all_gather_mlx_gen_orig_ #define mlx_distributed_all_max mlx_distributed_all_max_mlx_gen_orig_ #define mlx_distributed_all_min mlx_distributed_all_min_mlx_gen_orig_ @@ -143,11 +135,6 @@ #define mlx_distributed_recv_like mlx_distributed_recv_like_mlx_gen_orig_ #define mlx_distributed_send mlx_distributed_send_mlx_gen_orig_ #define mlx_distributed_sum_scatter mlx_distributed_sum_scatter_mlx_gen_orig_ -#define mlx_distributed_group_rank mlx_distributed_group_rank_mlx_gen_orig_ -#define mlx_distributed_group_size mlx_distributed_group_size_mlx_gen_orig_ -#define mlx_distributed_group_split mlx_distributed_group_split_mlx_gen_orig_ -#define mlx_distributed_is_available mlx_distributed_is_available_mlx_gen_orig_ -#define mlx_distributed_init mlx_distributed_init_mlx_gen_orig_ #define mlx_set_error_handler mlx_set_error_handler_mlx_gen_orig_ #define _mlx_error _mlx_error_mlx_gen_orig_ #define mlx_export_function mlx_export_function_mlx_gen_orig_ @@ -189,7 +176,6 @@ #define mlx_fast_metal_kernel_apply mlx_fast_metal_kernel_apply_mlx_gen_orig_ #define mlx_fast_rms_norm mlx_fast_rms_norm_mlx_gen_orig_ #define mlx_fast_rope mlx_fast_rope_mlx_gen_orig_ -#define mlx_fast_rope_dynamic mlx_fast_rope_dynamic_mlx_gen_orig_ #define mlx_fast_scaled_dot_product_attention mlx_fast_scaled_dot_product_attention_mlx_gen_orig_ #define mlx_fft_fft mlx_fft_fft_mlx_gen_orig_ #define mlx_fft_fft2 mlx_fft_fft2_mlx_gen_orig_ @@ -205,14 +191,6 @@ #define mlx_fft_rfft mlx_fft_rfft_mlx_gen_orig_ #define mlx_fft_rfft2 mlx_fft_rfft2_mlx_gen_orig_ #define mlx_fft_rfftn mlx_fft_rfftn_mlx_gen_orig_ -#define mlx_load_reader mlx_load_reader_mlx_gen_orig_ -#define mlx_load mlx_load_mlx_gen_orig_ -#define mlx_load_safetensors_reader mlx_load_safetensors_reader_mlx_gen_orig_ -#define mlx_load_safetensors mlx_load_safetensors_mlx_gen_orig_ -#define mlx_save_writer mlx_save_writer_mlx_gen_orig_ -#define mlx_save mlx_save_mlx_gen_orig_ -#define mlx_save_safetensors_writer mlx_save_safetensors_writer_mlx_gen_orig_ -#define mlx_save_safetensors mlx_save_safetensors_mlx_gen_orig_ #define mlx_io_reader_new mlx_io_reader_new_mlx_gen_orig_ #define mlx_io_reader_descriptor mlx_io_reader_descriptor_mlx_gen_orig_ #define mlx_io_reader_tostring mlx_io_reader_tostring_mlx_gen_orig_ @@ -221,6 +199,14 @@ #define mlx_io_writer_descriptor mlx_io_writer_descriptor_mlx_gen_orig_ #define mlx_io_writer_tostring mlx_io_writer_tostring_mlx_gen_orig_ #define mlx_io_writer_free mlx_io_writer_free_mlx_gen_orig_ +#define mlx_load_reader mlx_load_reader_mlx_gen_orig_ +#define mlx_load mlx_load_mlx_gen_orig_ +#define mlx_load_safetensors_reader mlx_load_safetensors_reader_mlx_gen_orig_ +#define mlx_load_safetensors mlx_load_safetensors_mlx_gen_orig_ +#define mlx_save_writer mlx_save_writer_mlx_gen_orig_ +#define mlx_save mlx_save_mlx_gen_orig_ +#define mlx_save_safetensors_writer mlx_save_safetensors_writer_mlx_gen_orig_ +#define mlx_save_safetensors mlx_save_safetensors_mlx_gen_orig_ #define mlx_linalg_cholesky mlx_linalg_cholesky_mlx_gen_orig_ #define mlx_linalg_cholesky_inv mlx_linalg_cholesky_inv_mlx_gen_orig_ #define mlx_linalg_cross mlx_linalg_cross_mlx_gen_orig_ @@ -265,6 +251,7 @@ #define mlx_set_cache_limit mlx_set_cache_limit_mlx_gen_orig_ #define mlx_set_memory_limit mlx_set_memory_limit_mlx_gen_orig_ #define mlx_set_wired_limit mlx_set_wired_limit_mlx_gen_orig_ +#define mlx_metal_device_info mlx_metal_device_info_mlx_gen_orig_ #define mlx_metal_is_available mlx_metal_is_available_mlx_gen_orig_ #define mlx_metal_start_capture mlx_metal_start_capture_mlx_gen_orig_ #define mlx_metal_stop_capture mlx_metal_stop_capture_mlx_gen_orig_ @@ -350,7 +337,6 @@ #define mlx_full mlx_full_mlx_gen_orig_ #define mlx_full_like mlx_full_like_mlx_gen_orig_ #define mlx_gather mlx_gather_mlx_gen_orig_ -#define mlx_gather_single mlx_gather_single_mlx_gen_orig_ #define mlx_gather_mm mlx_gather_mm_mlx_gen_orig_ #define mlx_gather_qmm mlx_gather_qmm_mlx_gen_orig_ #define mlx_greater mlx_greater_mlx_gen_orig_ @@ -415,7 +401,6 @@ #define mlx_prod_axis mlx_prod_axis_mlx_gen_orig_ #define mlx_prod mlx_prod_mlx_gen_orig_ #define mlx_put_along_axis mlx_put_along_axis_mlx_gen_orig_ -#define mlx_qqmm mlx_qqmm_mlx_gen_orig_ #define mlx_quantize mlx_quantize_mlx_gen_orig_ #define mlx_quantized_matmul mlx_quantized_matmul_mlx_gen_orig_ #define mlx_radians mlx_radians_mlx_gen_orig_ @@ -432,16 +417,11 @@ #define mlx_round mlx_round_mlx_gen_orig_ #define mlx_rsqrt mlx_rsqrt_mlx_gen_orig_ #define mlx_scatter mlx_scatter_mlx_gen_orig_ -#define mlx_scatter_single mlx_scatter_single_mlx_gen_orig_ #define mlx_scatter_add mlx_scatter_add_mlx_gen_orig_ -#define mlx_scatter_add_single mlx_scatter_add_single_mlx_gen_orig_ #define mlx_scatter_add_axis mlx_scatter_add_axis_mlx_gen_orig_ #define mlx_scatter_max mlx_scatter_max_mlx_gen_orig_ -#define mlx_scatter_max_single mlx_scatter_max_single_mlx_gen_orig_ #define mlx_scatter_min mlx_scatter_min_mlx_gen_orig_ -#define mlx_scatter_min_single mlx_scatter_min_single_mlx_gen_orig_ #define mlx_scatter_prod mlx_scatter_prod_mlx_gen_orig_ -#define mlx_scatter_prod_single mlx_scatter_prod_single_mlx_gen_orig_ #define mlx_segmented_mm mlx_segmented_mm_mlx_gen_orig_ #define mlx_sigmoid mlx_sigmoid_mlx_gen_orig_ #define mlx_sign mlx_sign_mlx_gen_orig_ @@ -536,6 +516,8 @@ #define mlx_string_set mlx_string_set_mlx_gen_orig_ #define mlx_string_data mlx_string_data_mlx_gen_orig_ #define mlx_string_free mlx_string_free_mlx_gen_orig_ +#define mlx_detail_vmap_replace mlx_detail_vmap_replace_mlx_gen_orig_ +#define mlx_detail_vmap_trace mlx_detail_vmap_trace_mlx_gen_orig_ #define mlx_async_eval mlx_async_eval_mlx_gen_orig_ #define mlx_checkpoint mlx_checkpoint_mlx_gen_orig_ #define mlx_custom_function mlx_custom_function_mlx_gen_orig_ @@ -544,8 +526,6 @@ #define mlx_jvp mlx_jvp_mlx_gen_orig_ #define mlx_value_and_grad mlx_value_and_grad_mlx_gen_orig_ #define mlx_vjp mlx_vjp_mlx_gen_orig_ -#define mlx_detail_vmap_replace mlx_detail_vmap_replace_mlx_gen_orig_ -#define mlx_detail_vmap_trace mlx_detail_vmap_trace_mlx_gen_orig_ #define mlx_vector_array_new mlx_vector_array_new_mlx_gen_orig_ #define mlx_vector_array_set mlx_vector_array_set_mlx_gen_orig_ #define mlx_vector_array_free mlx_vector_array_free_mlx_gen_orig_ @@ -606,8 +586,6 @@ #undef mlx_array_new_double #undef mlx_array_new_complex #undef mlx_array_new_data -#undef mlx_array_new_data_managed -#undef mlx_array_new_data_managed_payload #undef mlx_array_set #undef mlx_array_set_bool #undef mlx_array_set_int @@ -703,7 +681,6 @@ #undef mlx_disable_compile #undef mlx_enable_compile #undef mlx_set_compile_mode -#undef mlx_cuda_is_available #undef mlx_device_new #undef mlx_device_new_type #undef mlx_device_free @@ -714,16 +691,11 @@ #undef mlx_device_get_type #undef mlx_get_default_device #undef mlx_set_default_device -#undef mlx_device_is_available -#undef mlx_device_count -#undef mlx_device_info_new -#undef mlx_device_info_get -#undef mlx_device_info_free -#undef mlx_device_info_has_key -#undef mlx_device_info_is_string -#undef mlx_device_info_get_string -#undef mlx_device_info_get_size -#undef mlx_device_info_get_keys +#undef mlx_distributed_group_rank +#undef mlx_distributed_group_size +#undef mlx_distributed_group_split +#undef mlx_distributed_is_available +#undef mlx_distributed_init #undef mlx_distributed_all_gather #undef mlx_distributed_all_max #undef mlx_distributed_all_min @@ -732,11 +704,6 @@ #undef mlx_distributed_recv_like #undef mlx_distributed_send #undef mlx_distributed_sum_scatter -#undef mlx_distributed_group_rank -#undef mlx_distributed_group_size -#undef mlx_distributed_group_split -#undef mlx_distributed_is_available -#undef mlx_distributed_init #undef mlx_set_error_handler #undef _mlx_error #undef mlx_export_function @@ -778,7 +745,6 @@ #undef mlx_fast_metal_kernel_apply #undef mlx_fast_rms_norm #undef mlx_fast_rope -#undef mlx_fast_rope_dynamic #undef mlx_fast_scaled_dot_product_attention #undef mlx_fft_fft #undef mlx_fft_fft2 @@ -794,14 +760,6 @@ #undef mlx_fft_rfft #undef mlx_fft_rfft2 #undef mlx_fft_rfftn -#undef mlx_load_reader -#undef mlx_load -#undef mlx_load_safetensors_reader -#undef mlx_load_safetensors -#undef mlx_save_writer -#undef mlx_save -#undef mlx_save_safetensors_writer -#undef mlx_save_safetensors #undef mlx_io_reader_new #undef mlx_io_reader_descriptor #undef mlx_io_reader_tostring @@ -810,6 +768,14 @@ #undef mlx_io_writer_descriptor #undef mlx_io_writer_tostring #undef mlx_io_writer_free +#undef mlx_load_reader +#undef mlx_load +#undef mlx_load_safetensors_reader +#undef mlx_load_safetensors +#undef mlx_save_writer +#undef mlx_save +#undef mlx_save_safetensors_writer +#undef mlx_save_safetensors #undef mlx_linalg_cholesky #undef mlx_linalg_cholesky_inv #undef mlx_linalg_cross @@ -854,6 +820,7 @@ #undef mlx_set_cache_limit #undef mlx_set_memory_limit #undef mlx_set_wired_limit +#undef mlx_metal_device_info #undef mlx_metal_is_available #undef mlx_metal_start_capture #undef mlx_metal_stop_capture @@ -939,7 +906,6 @@ #undef mlx_full #undef mlx_full_like #undef mlx_gather -#undef mlx_gather_single #undef mlx_gather_mm #undef mlx_gather_qmm #undef mlx_greater @@ -1004,7 +970,6 @@ #undef mlx_prod_axis #undef mlx_prod #undef mlx_put_along_axis -#undef mlx_qqmm #undef mlx_quantize #undef mlx_quantized_matmul #undef mlx_radians @@ -1021,16 +986,11 @@ #undef mlx_round #undef mlx_rsqrt #undef mlx_scatter -#undef mlx_scatter_single #undef mlx_scatter_add -#undef mlx_scatter_add_single #undef mlx_scatter_add_axis #undef mlx_scatter_max -#undef mlx_scatter_max_single #undef mlx_scatter_min -#undef mlx_scatter_min_single #undef mlx_scatter_prod -#undef mlx_scatter_prod_single #undef mlx_segmented_mm #undef mlx_sigmoid #undef mlx_sign @@ -1125,6 +1085,8 @@ #undef mlx_string_set #undef mlx_string_data #undef mlx_string_free +#undef mlx_detail_vmap_replace +#undef mlx_detail_vmap_trace #undef mlx_async_eval #undef mlx_checkpoint #undef mlx_custom_function @@ -1133,8 +1095,6 @@ #undef mlx_jvp #undef mlx_value_and_grad #undef mlx_vjp -#undef mlx_detail_vmap_replace -#undef mlx_detail_vmap_trace #undef mlx_vector_array_new #undef mlx_vector_array_set #undef mlx_vector_array_free @@ -1197,19 +1157,6 @@ extern mlx_array (*mlx_array_new_data_)( const int* shape, int dim, mlx_dtype dtype); -extern mlx_array (*mlx_array_new_data_managed_)( - void* data, - const int* shape, - int dim, - mlx_dtype dtype, - void (*dtor)(void*)); -extern mlx_array (*mlx_array_new_data_managed_payload_)( - void* data, - const int* shape, - int dim, - mlx_dtype dtype, - void* payload, - void (*dtor)(void*)); extern int (*mlx_array_set_)(mlx_array* arr, const mlx_array src); extern int (*mlx_array_set_bool_)(mlx_array* arr, bool val); extern int (*mlx_array_set_int_)(mlx_array* arr, int val); @@ -1244,7 +1191,7 @@ extern int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr); extern int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr); extern int (*mlx_array_item_float32_)(float* res, const mlx_array arr); extern int (*mlx_array_item_float64_)(double* res, const mlx_array arr); -extern int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr); +extern int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr); extern int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr); extern int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr); extern const bool * (*mlx_array_data_bool_)(const mlx_array arr); @@ -1258,7 +1205,7 @@ extern const int32_t * (*mlx_array_data_int32_)(const mlx_array arr); extern const int64_t * (*mlx_array_data_int64_)(const mlx_array arr); extern const float * (*mlx_array_data_float32_)(const mlx_array arr); extern const double * (*mlx_array_data_float64_)(const mlx_array arr); -extern const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr); +extern const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr); extern const float16_t * (*mlx_array_data_float16_)(const mlx_array arr); extern const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr); extern int (*_mlx_array_is_available_)(bool* res, const mlx_array arr); @@ -1282,11 +1229,10 @@ extern int (*mlx_closure_apply_)( extern mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)); extern mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void); extern int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls); -extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)( - int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_map_string_to_array)); +extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array)); extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -1325,12 +1271,11 @@ extern int (*mlx_closure_value_and_grad_apply_)( const mlx_vector_array input); extern mlx_closure_custom (*mlx_closure_custom_new_)(void); extern int (*mlx_closure_custom_free_)(mlx_closure_custom cls); -extern mlx_closure_custom (*mlx_closure_custom_new_func_)( - int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_vector_array, - const mlx_vector_array)); +extern mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array)); extern mlx_closure_custom (*mlx_closure_custom_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -1351,13 +1296,12 @@ extern int (*mlx_closure_custom_apply_)( const mlx_vector_array input_2); extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void); extern int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls); -extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)( - int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_vector_array, - const int*, - size_t _num)); +extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num)); extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -1380,13 +1324,12 @@ extern int (*mlx_closure_custom_jvp_apply_)( size_t input_2_num); extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void); extern int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls); -extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)( - int (*fun)( - mlx_vector_array*, - mlx_vector_int*, - const mlx_vector_array, - const int*, - size_t _num)); +extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num)); extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -1420,7 +1363,6 @@ extern int (*mlx_detail_compile_erase_)(uintptr_t fun_id); extern int (*mlx_disable_compile_)(void); extern int (*mlx_enable_compile_)(void); extern int (*mlx_set_compile_mode_)(mlx_compile_mode mode); -extern int (*mlx_cuda_is_available_)(bool* res); extern mlx_device (*mlx_device_new_)(void); extern mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index); extern int (*mlx_device_free_)(mlx_device dev); @@ -1431,28 +1373,11 @@ extern int (*mlx_device_get_index_)(int* index, mlx_device dev); extern int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev); extern int (*mlx_get_default_device_)(mlx_device* dev); extern int (*mlx_set_default_device_)(mlx_device dev); -extern int (*mlx_device_is_available_)(bool* avail, mlx_device dev); -extern int (*mlx_device_count_)(int* count, mlx_device_type type); -extern mlx_device_info (*mlx_device_info_new_)(void); -extern int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev); -extern int (*mlx_device_info_free_)(mlx_device_info info); -extern int (*mlx_device_info_has_key_)( - bool* exists, - mlx_device_info info, - const char* key); -extern int (*mlx_device_info_is_string_)( - bool* is_string, - mlx_device_info info, - const char* key); -extern int (*mlx_device_info_get_string_)( - const char** value, - mlx_device_info info, - const char* key); -extern int (*mlx_device_info_get_size_)( - size_t* value, - mlx_device_info info, - const char* key); -extern int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info); +extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group); +extern int (*mlx_distributed_group_size_)(mlx_distributed_group group); +extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key); +extern bool (*mlx_distributed_is_available_)(void); +extern mlx_distributed_group (*mlx_distributed_init_)(bool strict); extern int (*mlx_distributed_all_gather_)( mlx_array* res, const mlx_array x, @@ -1498,11 +1423,6 @@ extern int (*mlx_distributed_sum_scatter_)( const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s); -extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group); -extern int (*mlx_distributed_group_size_)(mlx_distributed_group group); -extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key); -extern bool (*mlx_distributed_is_available_)(void); -extern mlx_distributed_group (*mlx_distributed_init_)(bool strict); extern void (*mlx_set_error_handler_)( mlx_error_handler_func handler, void* data, @@ -1665,16 +1585,6 @@ extern int (*mlx_fast_rope_)( int offset, const mlx_array freqs /* may be null */, const mlx_stream s); -extern int (*mlx_fast_rope_dynamic_)( - mlx_array* res, - const mlx_array x, - int dims, - bool traditional, - mlx_optional_float base, - float scale, - const mlx_array offset, - const mlx_array freqs /* may be null */, - const mlx_stream s); extern int (*mlx_fast_scaled_dot_product_attention_)( mlx_array* res, const mlx_array queries, @@ -1785,6 +1695,14 @@ extern int (*mlx_fft_rfftn_)( const int* axes, size_t axes_num, const mlx_stream s); +extern mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable); +extern int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io); +extern int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io); +extern int (*mlx_io_reader_free_)(mlx_io_reader io); +extern mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable); +extern int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io); +extern int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io); +extern int (*mlx_io_writer_free_)(mlx_io_writer io); extern int (*mlx_load_reader_)( mlx_array* res, mlx_io_reader in_stream, @@ -1810,14 +1728,6 @@ extern int (*mlx_save_safetensors_)( const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata); -extern mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable); -extern int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io); -extern int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io); -extern int (*mlx_io_reader_free_)(mlx_io_reader io); -extern mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable); -extern int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io); -extern int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io); -extern int (*mlx_io_writer_free_)(mlx_io_writer io); extern int (*mlx_linalg_cholesky_)( mlx_array* res, const mlx_array a, @@ -1958,6 +1868,7 @@ extern int (*mlx_reset_peak_memory_)(void); extern int (*mlx_set_cache_limit_)(size_t* res, size_t limit); extern int (*mlx_set_memory_limit_)(size_t* res, size_t limit); extern int (*mlx_set_wired_limit_)(size_t* res, size_t limit); +extern mlx_metal_device_info_t (*mlx_metal_device_info_)(void); extern int (*mlx_metal_is_available_)(bool* res); extern int (*mlx_metal_start_capture_)(const char* path); extern int (*mlx_metal_stop_capture_)(void); @@ -2386,14 +2297,6 @@ extern int (*mlx_gather_)( const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s); -extern int (*mlx_gather_single_)( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - int axis, - const int* slice_sizes, - size_t slice_sizes_num, - const mlx_stream s); extern int (*mlx_gather_mm_)( mlx_array* res, const mlx_array a, @@ -2715,15 +2618,6 @@ extern int (*mlx_put_along_axis_)( const mlx_array values, int axis, const mlx_stream s); -extern int (*mlx_qqmm_)( - mlx_array* res, - const mlx_array x, - const mlx_array w, - const mlx_array w_scales /* may be null */, - mlx_optional_int group_size, - mlx_optional_int bits, - const char* mode, - const mlx_stream s); extern int (*mlx_quantize_)( mlx_vector_array* res, const mlx_array w, @@ -2807,13 +2701,6 @@ extern int (*mlx_scatter_)( const int* axes, size_t axes_num, const mlx_stream s); -extern int (*mlx_scatter_single_)( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s); extern int (*mlx_scatter_add_)( mlx_array* res, const mlx_array a, @@ -2822,13 +2709,6 @@ extern int (*mlx_scatter_add_)( const int* axes, size_t axes_num, const mlx_stream s); -extern int (*mlx_scatter_add_single_)( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s); extern int (*mlx_scatter_add_axis_)( mlx_array* res, const mlx_array a, @@ -2844,13 +2724,6 @@ extern int (*mlx_scatter_max_)( const int* axes, size_t axes_num, const mlx_stream s); -extern int (*mlx_scatter_max_single_)( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s); extern int (*mlx_scatter_min_)( mlx_array* res, const mlx_array a, @@ -2859,13 +2732,6 @@ extern int (*mlx_scatter_min_)( const int* axes, size_t axes_num, const mlx_stream s); -extern int (*mlx_scatter_min_single_)( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s); extern int (*mlx_scatter_prod_)( mlx_array* res, const mlx_array a, @@ -2874,13 +2740,6 @@ extern int (*mlx_scatter_prod_)( const int* axes, size_t axes_num, const mlx_stream s); -extern int (*mlx_scatter_prod_single_)( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s); extern int (*mlx_segmented_mm_)( mlx_array* res, const mlx_array a, @@ -3304,6 +3163,22 @@ extern mlx_string (*mlx_string_new_data_)(const char* str); extern int (*mlx_string_set_)(mlx_string* str, const mlx_string src); extern const char * (*mlx_string_data_)(mlx_string str); extern int (*mlx_string_free_)(mlx_string str); +extern int (*mlx_detail_vmap_replace_)( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array s_inputs, + const mlx_vector_array s_outputs, + const int* in_axes, + size_t in_axes_num, + const int* out_axes, + size_t out_axes_num); +extern int (*mlx_detail_vmap_trace_)( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array inputs, + const int* in_axes, + size_t in_axes_num); extern int (*mlx_async_eval_)(const mlx_vector_array outputs); extern int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun); extern int (*mlx_custom_function_)( @@ -3334,22 +3209,6 @@ extern int (*mlx_vjp_)( const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents); -extern int (*mlx_detail_vmap_replace_)( - mlx_vector_array* res, - const mlx_vector_array inputs, - const mlx_vector_array s_inputs, - const mlx_vector_array s_outputs, - const int* in_axes, - size_t in_axes_num, - const int* out_axes, - size_t out_axes_num); -extern int (*mlx_detail_vmap_trace_)( - mlx_vector_array* res_0, - mlx_vector_array* res_1, - const mlx_closure fun, - const mlx_vector_array inputs, - const int* in_axes, - size_t in_axes_num); extern mlx_vector_array (*mlx_vector_array_new_)(void); extern int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src); extern int (*mlx_vector_array_free_)(mlx_vector_array vec); @@ -3434,36 +3293,47 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle); static inline size_t mlx_dtype_size(mlx_dtype dtype) { return mlx_dtype_size_(dtype); } + static inline int mlx_array_tostring(mlx_string* str, const mlx_array arr) { return mlx_array_tostring_(str, arr); } + static inline mlx_array mlx_array_new(void) { return mlx_array_new_(); } + static inline int mlx_array_free(mlx_array arr) { return mlx_array_free_(arr); } + static inline mlx_array mlx_array_new_bool(bool val) { return mlx_array_new_bool_(val); } + static inline mlx_array mlx_array_new_int(int val) { return mlx_array_new_int_(val); } + static inline mlx_array mlx_array_new_float32(float val) { return mlx_array_new_float32_(val); } + static inline mlx_array mlx_array_new_float(float val) { return mlx_array_new_float_(val); } + static inline mlx_array mlx_array_new_float64(double val) { return mlx_array_new_float64_(val); } + static inline mlx_array mlx_array_new_double(double val) { return mlx_array_new_double_(val); } + static inline mlx_array mlx_array_new_complex(float real_val, float imag_val) { return mlx_array_new_complex_(real_val, imag_val); } + static inline mlx_array mlx_array_new_data( const void* data, const int* shape, @@ -3471,47 +3341,39 @@ static inline mlx_array mlx_array_new_data( mlx_dtype dtype) { return mlx_array_new_data_(data, shape, dim, dtype); } -static inline mlx_array mlx_array_new_data_managed( - void* data, - const int* shape, - int dim, - mlx_dtype dtype, - void (*dtor)(void*)) { - return mlx_array_new_data_managed_(data, shape, dim, dtype, dtor); -} -static inline mlx_array mlx_array_new_data_managed_payload( - void* data, - const int* shape, - int dim, - mlx_dtype dtype, - void* payload, - void (*dtor)(void*)) { - return mlx_array_new_data_managed_payload_(data, shape, dim, dtype, payload, dtor); -} + static inline int mlx_array_set(mlx_array* arr, const mlx_array src) { return mlx_array_set_(arr, src); } + static inline int mlx_array_set_bool(mlx_array* arr, bool val) { return mlx_array_set_bool_(arr, val); } + static inline int mlx_array_set_int(mlx_array* arr, int val) { return mlx_array_set_int_(arr, val); } + static inline int mlx_array_set_float32(mlx_array* arr, float val) { return mlx_array_set_float32_(arr, val); } + static inline int mlx_array_set_float(mlx_array* arr, float val) { return mlx_array_set_float_(arr, val); } + static inline int mlx_array_set_float64(mlx_array* arr, double val) { return mlx_array_set_float64_(arr, val); } + static inline int mlx_array_set_double(mlx_array* arr, double val) { return mlx_array_set_double_(arr, val); } + static inline int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val) { return mlx_array_set_complex_(arr, real_val, imag_val); } + static inline int mlx_array_set_data( mlx_array* arr, const void* data, @@ -3520,173 +3382,225 @@ static inline int mlx_array_set_data( mlx_dtype dtype) { return mlx_array_set_data_(arr, data, shape, dim, dtype); } + static inline size_t mlx_array_itemsize(const mlx_array arr) { return mlx_array_itemsize_(arr); } + static inline size_t mlx_array_size(const mlx_array arr) { return mlx_array_size_(arr); } + static inline size_t mlx_array_nbytes(const mlx_array arr) { return mlx_array_nbytes_(arr); } + static inline size_t mlx_array_ndim(const mlx_array arr) { return mlx_array_ndim_(arr); } + static inline const int * mlx_array_shape(const mlx_array arr) { return mlx_array_shape_(arr); } + static inline const size_t * mlx_array_strides(const mlx_array arr) { return mlx_array_strides_(arr); } + static inline int mlx_array_dim(const mlx_array arr, int dim) { return mlx_array_dim_(arr, dim); } + static inline mlx_dtype mlx_array_dtype(const mlx_array arr) { return mlx_array_dtype_(arr); } + static inline int mlx_array_eval(mlx_array arr) { return mlx_array_eval_(arr); } + static inline int mlx_array_item_bool(bool* res, const mlx_array arr) { return mlx_array_item_bool_(res, arr); } + static inline int mlx_array_item_uint8(uint8_t* res, const mlx_array arr) { return mlx_array_item_uint8_(res, arr); } + static inline int mlx_array_item_uint16(uint16_t* res, const mlx_array arr) { return mlx_array_item_uint16_(res, arr); } + static inline int mlx_array_item_uint32(uint32_t* res, const mlx_array arr) { return mlx_array_item_uint32_(res, arr); } + static inline int mlx_array_item_uint64(uint64_t* res, const mlx_array arr) { return mlx_array_item_uint64_(res, arr); } + static inline int mlx_array_item_int8(int8_t* res, const mlx_array arr) { return mlx_array_item_int8_(res, arr); } + static inline int mlx_array_item_int16(int16_t* res, const mlx_array arr) { return mlx_array_item_int16_(res, arr); } + static inline int mlx_array_item_int32(int32_t* res, const mlx_array arr) { return mlx_array_item_int32_(res, arr); } + static inline int mlx_array_item_int64(int64_t* res, const mlx_array arr) { return mlx_array_item_int64_(res, arr); } + static inline int mlx_array_item_float32(float* res, const mlx_array arr) { return mlx_array_item_float32_(res, arr); } + static inline int mlx_array_item_float64(double* res, const mlx_array arr) { return mlx_array_item_float64_(res, arr); } -static inline int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) { + +static inline int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) { return mlx_array_item_complex64_(res, arr); } + static inline int mlx_array_item_float16(float16_t* res, const mlx_array arr) { return mlx_array_item_float16_(res, arr); } + static inline int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr) { return mlx_array_item_bfloat16_(res, arr); } + static inline const bool * mlx_array_data_bool(const mlx_array arr) { return mlx_array_data_bool_(arr); } + static inline const uint8_t * mlx_array_data_uint8(const mlx_array arr) { return mlx_array_data_uint8_(arr); } + static inline const uint16_t * mlx_array_data_uint16(const mlx_array arr) { return mlx_array_data_uint16_(arr); } + static inline const uint32_t * mlx_array_data_uint32(const mlx_array arr) { return mlx_array_data_uint32_(arr); } + static inline const uint64_t * mlx_array_data_uint64(const mlx_array arr) { return mlx_array_data_uint64_(arr); } + static inline const int8_t * mlx_array_data_int8(const mlx_array arr) { return mlx_array_data_int8_(arr); } + static inline const int16_t * mlx_array_data_int16(const mlx_array arr) { return mlx_array_data_int16_(arr); } + static inline const int32_t * mlx_array_data_int32(const mlx_array arr) { return mlx_array_data_int32_(arr); } + static inline const int64_t * mlx_array_data_int64(const mlx_array arr) { return mlx_array_data_int64_(arr); } + static inline const float * mlx_array_data_float32(const mlx_array arr) { return mlx_array_data_float32_(arr); } + static inline const double * mlx_array_data_float64(const mlx_array arr) { return mlx_array_data_float64_(arr); } -static inline const mlx_complex64_t * mlx_array_data_complex64(const mlx_array arr) { + +static inline const float _Complex * mlx_array_data_complex64(const mlx_array arr) { return mlx_array_data_complex64_(arr); } + static inline const float16_t * mlx_array_data_float16(const mlx_array arr) { return mlx_array_data_float16_(arr); } + static inline const bfloat16_t * mlx_array_data_bfloat16(const mlx_array arr) { return mlx_array_data_bfloat16_(arr); } + static inline int _mlx_array_is_available(bool* res, const mlx_array arr) { return _mlx_array_is_available_(res, arr); } + static inline int _mlx_array_wait(const mlx_array arr) { return _mlx_array_wait_(arr); } + static inline int _mlx_array_is_contiguous(bool* res, const mlx_array arr) { return _mlx_array_is_contiguous_(res, arr); } + static inline int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr) { return _mlx_array_is_row_contiguous_(res, arr); } + static inline int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr) { return _mlx_array_is_col_contiguous_(res, arr); } + static inline mlx_closure mlx_closure_new(void) { return mlx_closure_new_(); } + static inline int mlx_closure_free(mlx_closure cls) { return mlx_closure_free_(cls); } + static inline mlx_closure mlx_closure_new_func( int (*fun)(mlx_vector_array*, const mlx_vector_array)) { return mlx_closure_new_func_(fun); } + static inline mlx_closure mlx_closure_new_func_payload( int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_new_func_payload_(fun, payload, dtor); } + static inline int mlx_closure_set(mlx_closure* cls, const mlx_closure src) { return mlx_closure_set_(cls, src); } + static inline int mlx_closure_apply( mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input) { return mlx_closure_apply_(res, cls, input); } + static inline mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)) { return mlx_closure_new_unary_(fun); } + static inline mlx_closure_kwargs mlx_closure_kwargs_new(void) { return mlx_closure_kwargs_new_(); } + static inline int mlx_closure_kwargs_free(mlx_closure_kwargs cls) { return mlx_closure_kwargs_free_(cls); } -static inline mlx_closure_kwargs mlx_closure_kwargs_new_func( - int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_map_string_to_array)) { + +static inline mlx_closure_kwargs mlx_closure_kwargs_new_func(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array)) { return mlx_closure_kwargs_new_func_(fun); } + static inline mlx_closure_kwargs mlx_closure_kwargs_new_func_payload( int (*fun)( mlx_vector_array*, @@ -3697,11 +3611,13 @@ static inline mlx_closure_kwargs mlx_closure_kwargs_new_func_payload( void (*dtor)(void*)) { return mlx_closure_kwargs_new_func_payload_(fun, payload, dtor); } + static inline int mlx_closure_kwargs_set( mlx_closure_kwargs* cls, const mlx_closure_kwargs src) { return mlx_closure_kwargs_set_(cls, src); } + static inline int mlx_closure_kwargs_apply( mlx_vector_array* res, mlx_closure_kwargs cls, @@ -3709,16 +3625,20 @@ static inline int mlx_closure_kwargs_apply( const mlx_map_string_to_array input_1) { return mlx_closure_kwargs_apply_(res, cls, input_0, input_1); } + static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void) { return mlx_closure_value_and_grad_new_(); } + static inline int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls) { return mlx_closure_value_and_grad_free_(cls); } + static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func( int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) { return mlx_closure_value_and_grad_new_func_(fun); } + static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload( int (*fun)( mlx_vector_array*, @@ -3729,11 +3649,13 @@ static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_pay void (*dtor)(void*)) { return mlx_closure_value_and_grad_new_func_payload_(fun, payload, dtor); } + static inline int mlx_closure_value_and_grad_set( mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src) { return mlx_closure_value_and_grad_set_(cls, src); } + static inline int mlx_closure_value_and_grad_apply( mlx_vector_array* res_0, mlx_vector_array* res_1, @@ -3741,20 +3663,23 @@ static inline int mlx_closure_value_and_grad_apply( const mlx_vector_array input) { return mlx_closure_value_and_grad_apply_(res_0, res_1, cls, input); } + static inline mlx_closure_custom mlx_closure_custom_new(void) { return mlx_closure_custom_new_(); } + static inline int mlx_closure_custom_free(mlx_closure_custom cls) { return mlx_closure_custom_free_(cls); } -static inline mlx_closure_custom mlx_closure_custom_new_func( - int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_vector_array, - const mlx_vector_array)) { + +static inline mlx_closure_custom mlx_closure_custom_new_func(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array)) { return mlx_closure_custom_new_func_(fun); } + static inline mlx_closure_custom mlx_closure_custom_new_func_payload( int (*fun)( mlx_vector_array*, @@ -3766,11 +3691,13 @@ static inline mlx_closure_custom mlx_closure_custom_new_func_payload( void (*dtor)(void*)) { return mlx_closure_custom_new_func_payload_(fun, payload, dtor); } + static inline int mlx_closure_custom_set( mlx_closure_custom* cls, const mlx_closure_custom src) { return mlx_closure_custom_set_(cls, src); } + static inline int mlx_closure_custom_apply( mlx_vector_array* res, mlx_closure_custom cls, @@ -3779,21 +3706,24 @@ static inline int mlx_closure_custom_apply( const mlx_vector_array input_2) { return mlx_closure_custom_apply_(res, cls, input_0, input_1, input_2); } + static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void) { return mlx_closure_custom_jvp_new_(); } + static inline int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls) { return mlx_closure_custom_jvp_free_(cls); } -static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func( - int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_vector_array, - const int*, - size_t _num)) { + +static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num)) { return mlx_closure_custom_jvp_new_func_(fun); } + static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload( int (*fun)( mlx_vector_array*, @@ -3806,11 +3736,13 @@ static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload( void (*dtor)(void*)) { return mlx_closure_custom_jvp_new_func_payload_(fun, payload, dtor); } + static inline int mlx_closure_custom_jvp_set( mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src) { return mlx_closure_custom_jvp_set_(cls, src); } + static inline int mlx_closure_custom_jvp_apply( mlx_vector_array* res, mlx_closure_custom_jvp cls, @@ -3820,21 +3752,24 @@ static inline int mlx_closure_custom_jvp_apply( size_t input_2_num) { return mlx_closure_custom_jvp_apply_(res, cls, input_0, input_1, input_2, input_2_num); } + static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void) { return mlx_closure_custom_vmap_new_(); } + static inline int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls) { return mlx_closure_custom_vmap_free_(cls); } -static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func( - int (*fun)( - mlx_vector_array*, - mlx_vector_int*, - const mlx_vector_array, - const int*, - size_t _num)) { + +static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num)) { return mlx_closure_custom_vmap_new_func_(fun); } + static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload( int (*fun)( mlx_vector_array*, @@ -3847,11 +3782,13 @@ static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload( void (*dtor)(void*)) { return mlx_closure_custom_vmap_new_func_payload_(fun, payload, dtor); } + static inline int mlx_closure_custom_vmap_set( mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src) { return mlx_closure_custom_vmap_set_(cls, src); } + static inline int mlx_closure_custom_vmap_apply( mlx_vector_array* res_0, mlx_vector_int* res_1, @@ -3861,9 +3798,11 @@ static inline int mlx_closure_custom_vmap_apply( size_t input_1_num) { return mlx_closure_custom_vmap_apply_(res_0, res_1, cls, input_0, input_1, input_1_num); } + static inline int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless) { return mlx_compile_(res, fun, shapeless); } + static inline int mlx_detail_compile( mlx_closure* res, const mlx_closure fun, @@ -3873,96 +3812,87 @@ static inline int mlx_detail_compile( size_t constants_num) { return mlx_detail_compile_(res, fun, fun_id, shapeless, constants, constants_num); } + static inline int mlx_detail_compile_clear_cache(void) { return mlx_detail_compile_clear_cache_(); } + static inline int mlx_detail_compile_erase(uintptr_t fun_id) { return mlx_detail_compile_erase_(fun_id); } + static inline int mlx_disable_compile(void) { return mlx_disable_compile_(); } + static inline int mlx_enable_compile(void) { return mlx_enable_compile_(); } + static inline int mlx_set_compile_mode(mlx_compile_mode mode) { return mlx_set_compile_mode_(mode); } -static inline int mlx_cuda_is_available(bool* res) { - return mlx_cuda_is_available_(res); -} + static inline mlx_device mlx_device_new(void) { return mlx_device_new_(); } + static inline mlx_device mlx_device_new_type(mlx_device_type type, int index) { return mlx_device_new_type_(type, index); } + static inline int mlx_device_free(mlx_device dev) { return mlx_device_free_(dev); } + static inline int mlx_device_set(mlx_device* dev, const mlx_device src) { return mlx_device_set_(dev, src); } + static inline int mlx_device_tostring(mlx_string* str, mlx_device dev) { return mlx_device_tostring_(str, dev); } + static inline bool mlx_device_equal(mlx_device lhs, mlx_device rhs) { return mlx_device_equal_(lhs, rhs); } + static inline int mlx_device_get_index(int* index, mlx_device dev) { return mlx_device_get_index_(index, dev); } + static inline int mlx_device_get_type(mlx_device_type* type, mlx_device dev) { return mlx_device_get_type_(type, dev); } + static inline int mlx_get_default_device(mlx_device* dev) { return mlx_get_default_device_(dev); } + static inline int mlx_set_default_device(mlx_device dev) { return mlx_set_default_device_(dev); } -static inline int mlx_device_is_available(bool* avail, mlx_device dev) { - return mlx_device_is_available_(avail, dev); + +static inline int mlx_distributed_group_rank(mlx_distributed_group group) { + return mlx_distributed_group_rank_(group); } -static inline int mlx_device_count(int* count, mlx_device_type type) { - return mlx_device_count_(count, type); + +static inline int mlx_distributed_group_size(mlx_distributed_group group) { + return mlx_distributed_group_size_(group); } -static inline mlx_device_info mlx_device_info_new(void) { - return mlx_device_info_new_(); + +static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { + return mlx_distributed_group_split_(group, color, key); } -static inline int mlx_device_info_get(mlx_device_info* info, mlx_device dev) { - return mlx_device_info_get_(info, dev); + +static inline bool mlx_distributed_is_available(void) { + return mlx_distributed_is_available_(); } -static inline int mlx_device_info_free(mlx_device_info info) { - return mlx_device_info_free_(info); -} -static inline int mlx_device_info_has_key( - bool* exists, - mlx_device_info info, - const char* key) { - return mlx_device_info_has_key_(exists, info, key); -} -static inline int mlx_device_info_is_string( - bool* is_string, - mlx_device_info info, - const char* key) { - return mlx_device_info_is_string_(is_string, info, key); -} -static inline int mlx_device_info_get_string( - const char** value, - mlx_device_info info, - const char* key) { - return mlx_device_info_get_string_(value, info, key); -} -static inline int mlx_device_info_get_size( - size_t* value, - mlx_device_info info, - const char* key) { - return mlx_device_info_get_size_(value, info, key); -} -static inline int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) { - return mlx_device_info_get_keys_(keys, info); + +static inline mlx_distributed_group mlx_distributed_init(bool strict) { + return mlx_distributed_init_(strict); } + static inline int mlx_distributed_all_gather( mlx_array* res, const mlx_array x, @@ -3970,6 +3900,7 @@ static inline int mlx_distributed_all_gather( const mlx_stream S) { return mlx_distributed_all_gather_(res, x, group, S); } + static inline int mlx_distributed_all_max( mlx_array* res, const mlx_array x, @@ -3977,6 +3908,7 @@ static inline int mlx_distributed_all_max( const mlx_stream s) { return mlx_distributed_all_max_(res, x, group, s); } + static inline int mlx_distributed_all_min( mlx_array* res, const mlx_array x, @@ -3984,6 +3916,7 @@ static inline int mlx_distributed_all_min( const mlx_stream s) { return mlx_distributed_all_min_(res, x, group, s); } + static inline int mlx_distributed_all_sum( mlx_array* res, const mlx_array x, @@ -3991,6 +3924,7 @@ static inline int mlx_distributed_all_sum( const mlx_stream s) { return mlx_distributed_all_sum_(res, x, group, s); } + static inline int mlx_distributed_recv( mlx_array* res, const int* shape, @@ -4001,6 +3935,7 @@ static inline int mlx_distributed_recv( const mlx_stream s) { return mlx_distributed_recv_(res, shape, shape_num, dtype, src, group, s); } + static inline int mlx_distributed_recv_like( mlx_array* res, const mlx_array x, @@ -4009,6 +3944,7 @@ static inline int mlx_distributed_recv_like( const mlx_stream s) { return mlx_distributed_recv_like_(res, x, src, group, s); } + static inline int mlx_distributed_send( mlx_array* res, const mlx_array x, @@ -4017,6 +3953,7 @@ static inline int mlx_distributed_send( const mlx_stream s) { return mlx_distributed_send_(res, x, dst, group, s); } + static inline int mlx_distributed_sum_scatter( mlx_array* res, const mlx_array x, @@ -4024,30 +3961,16 @@ static inline int mlx_distributed_sum_scatter( const mlx_stream s) { return mlx_distributed_sum_scatter_(res, x, group, s); } -static inline int mlx_distributed_group_rank(mlx_distributed_group group) { - return mlx_distributed_group_rank_(group); -} -static inline int mlx_distributed_group_size(mlx_distributed_group group) { - return mlx_distributed_group_size_(group); -} -static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { - return mlx_distributed_group_split_(group, color, key); -} -static inline bool mlx_distributed_is_available(void) { - return mlx_distributed_is_available_(); -} -static inline mlx_distributed_group mlx_distributed_init(bool strict) { - return mlx_distributed_init_(strict); -} + static inline void mlx_set_error_handler( mlx_error_handler_func handler, void* data, void (*dtor)(void*)) { - return mlx_set_error_handler_(handler, data, dtor); -} -static inline void _mlx_error(const char* file, const int line, const char* fmt, ...) { - return _mlx_error_(file, line, fmt); + mlx_set_error_handler_(handler, data, dtor); } + +#define _mlx_error(file, line, fmt, ...) _mlx_error_(file, line, fmt, __VA_ARGS__) + static inline int mlx_export_function( const char* file, const mlx_closure fun, @@ -4055,6 +3978,7 @@ static inline int mlx_export_function( bool shapeless) { return mlx_export_function_(file, fun, args, shapeless); } + static inline int mlx_export_function_kwargs( const char* file, const mlx_closure_kwargs fun, @@ -4063,38 +3987,46 @@ static inline int mlx_export_function_kwargs( bool shapeless) { return mlx_export_function_kwargs_(file, fun, args, kwargs, shapeless); } + static inline mlx_function_exporter mlx_function_exporter_new( const char* file, const mlx_closure fun, bool shapeless) { return mlx_function_exporter_new_(file, fun, shapeless); } + static inline int mlx_function_exporter_free(mlx_function_exporter xfunc) { return mlx_function_exporter_free_(xfunc); } + static inline int mlx_function_exporter_apply( const mlx_function_exporter xfunc, const mlx_vector_array args) { return mlx_function_exporter_apply_(xfunc, args); } + static inline int mlx_function_exporter_apply_kwargs( const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) { return mlx_function_exporter_apply_kwargs_(xfunc, args, kwargs); } + static inline mlx_imported_function mlx_imported_function_new(const char* file) { return mlx_imported_function_new_(file); } + static inline int mlx_imported_function_free(mlx_imported_function xfunc) { return mlx_imported_function_free_(xfunc); } + static inline int mlx_imported_function_apply( mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args) { return mlx_imported_function_apply_(res, xfunc, args); } + static inline int mlx_imported_function_apply_kwargs( mlx_vector_array* res, const mlx_imported_function xfunc, @@ -4102,12 +4034,15 @@ static inline int mlx_imported_function_apply_kwargs( const mlx_map_string_to_array kwargs) { return mlx_imported_function_apply_kwargs_(res, xfunc, args, kwargs); } + static inline mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void) { return mlx_fast_cuda_kernel_config_new_(); } + static inline void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls) { - return mlx_fast_cuda_kernel_config_free_(cls); + mlx_fast_cuda_kernel_config_free_(cls); } + static inline int mlx_fast_cuda_kernel_config_add_output_arg( mlx_fast_cuda_kernel_config cls, const int* shape, @@ -4115,6 +4050,7 @@ static inline int mlx_fast_cuda_kernel_config_add_output_arg( mlx_dtype dtype) { return mlx_fast_cuda_kernel_config_add_output_arg_(cls, shape, size, dtype); } + static inline int mlx_fast_cuda_kernel_config_set_grid( mlx_fast_cuda_kernel_config cls, int grid1, @@ -4122,6 +4058,7 @@ static inline int mlx_fast_cuda_kernel_config_set_grid( int grid3) { return mlx_fast_cuda_kernel_config_set_grid_(cls, grid1, grid2, grid3); } + static inline int mlx_fast_cuda_kernel_config_set_thread_group( mlx_fast_cuda_kernel_config cls, int thread1, @@ -4129,34 +4066,40 @@ static inline int mlx_fast_cuda_kernel_config_set_thread_group( int thread3) { return mlx_fast_cuda_kernel_config_set_thread_group_(cls, thread1, thread2, thread3); } + static inline int mlx_fast_cuda_kernel_config_set_init_value( mlx_fast_cuda_kernel_config cls, float value) { return mlx_fast_cuda_kernel_config_set_init_value_(cls, value); } + static inline int mlx_fast_cuda_kernel_config_set_verbose( mlx_fast_cuda_kernel_config cls, bool verbose) { return mlx_fast_cuda_kernel_config_set_verbose_(cls, verbose); } + static inline int mlx_fast_cuda_kernel_config_add_template_arg_dtype( mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype) { return mlx_fast_cuda_kernel_config_add_template_arg_dtype_(cls, name, dtype); } + static inline int mlx_fast_cuda_kernel_config_add_template_arg_int( mlx_fast_cuda_kernel_config cls, const char* name, int value) { return mlx_fast_cuda_kernel_config_add_template_arg_int_(cls, name, value); } + static inline int mlx_fast_cuda_kernel_config_add_template_arg_bool( mlx_fast_cuda_kernel_config cls, const char* name, bool value) { return mlx_fast_cuda_kernel_config_add_template_arg_bool_(cls, name, value); } + static inline mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new( const char* name, const mlx_vector_string input_names, @@ -4167,9 +4110,11 @@ static inline mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new( int shared_memory) { return mlx_fast_cuda_kernel_new_(name, input_names, output_names, source, header, ensure_row_contiguous, shared_memory); } + static inline void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls) { - return mlx_fast_cuda_kernel_free_(cls); + mlx_fast_cuda_kernel_free_(cls); } + static inline int mlx_fast_cuda_kernel_apply( mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, @@ -4178,6 +4123,7 @@ static inline int mlx_fast_cuda_kernel_apply( const mlx_stream stream) { return mlx_fast_cuda_kernel_apply_(outputs, cls, inputs, config, stream); } + static inline int mlx_fast_layer_norm( mlx_array* res, const mlx_array x, @@ -4187,12 +4133,15 @@ static inline int mlx_fast_layer_norm( const mlx_stream s) { return mlx_fast_layer_norm_(res, x, weight, bias, eps, s); } + static inline mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void) { return mlx_fast_metal_kernel_config_new_(); } + static inline void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls) { - return mlx_fast_metal_kernel_config_free_(cls); + mlx_fast_metal_kernel_config_free_(cls); } + static inline int mlx_fast_metal_kernel_config_add_output_arg( mlx_fast_metal_kernel_config cls, const int* shape, @@ -4200,6 +4149,7 @@ static inline int mlx_fast_metal_kernel_config_add_output_arg( mlx_dtype dtype) { return mlx_fast_metal_kernel_config_add_output_arg_(cls, shape, size, dtype); } + static inline int mlx_fast_metal_kernel_config_set_grid( mlx_fast_metal_kernel_config cls, int grid1, @@ -4207,6 +4157,7 @@ static inline int mlx_fast_metal_kernel_config_set_grid( int grid3) { return mlx_fast_metal_kernel_config_set_grid_(cls, grid1, grid2, grid3); } + static inline int mlx_fast_metal_kernel_config_set_thread_group( mlx_fast_metal_kernel_config cls, int thread1, @@ -4214,34 +4165,40 @@ static inline int mlx_fast_metal_kernel_config_set_thread_group( int thread3) { return mlx_fast_metal_kernel_config_set_thread_group_(cls, thread1, thread2, thread3); } + static inline int mlx_fast_metal_kernel_config_set_init_value( mlx_fast_metal_kernel_config cls, float value) { return mlx_fast_metal_kernel_config_set_init_value_(cls, value); } + static inline int mlx_fast_metal_kernel_config_set_verbose( mlx_fast_metal_kernel_config cls, bool verbose) { return mlx_fast_metal_kernel_config_set_verbose_(cls, verbose); } + static inline int mlx_fast_metal_kernel_config_add_template_arg_dtype( mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype) { return mlx_fast_metal_kernel_config_add_template_arg_dtype_(cls, name, dtype); } + static inline int mlx_fast_metal_kernel_config_add_template_arg_int( mlx_fast_metal_kernel_config cls, const char* name, int value) { return mlx_fast_metal_kernel_config_add_template_arg_int_(cls, name, value); } + static inline int mlx_fast_metal_kernel_config_add_template_arg_bool( mlx_fast_metal_kernel_config cls, const char* name, bool value) { return mlx_fast_metal_kernel_config_add_template_arg_bool_(cls, name, value); } + static inline mlx_fast_metal_kernel mlx_fast_metal_kernel_new( const char* name, const mlx_vector_string input_names, @@ -4252,9 +4209,11 @@ static inline mlx_fast_metal_kernel mlx_fast_metal_kernel_new( bool atomic_outputs) { return mlx_fast_metal_kernel_new_(name, input_names, output_names, source, header, ensure_row_contiguous, atomic_outputs); } + static inline void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls) { - return mlx_fast_metal_kernel_free_(cls); + mlx_fast_metal_kernel_free_(cls); } + static inline int mlx_fast_metal_kernel_apply( mlx_vector_array* outputs, mlx_fast_metal_kernel cls, @@ -4263,6 +4222,7 @@ static inline int mlx_fast_metal_kernel_apply( const mlx_stream stream) { return mlx_fast_metal_kernel_apply_(outputs, cls, inputs, config, stream); } + static inline int mlx_fast_rms_norm( mlx_array* res, const mlx_array x, @@ -4271,6 +4231,7 @@ static inline int mlx_fast_rms_norm( const mlx_stream s) { return mlx_fast_rms_norm_(res, x, weight, eps, s); } + static inline int mlx_fast_rope( mlx_array* res, const mlx_array x, @@ -4283,18 +4244,7 @@ static inline int mlx_fast_rope( const mlx_stream s) { return mlx_fast_rope_(res, x, dims, traditional, base, scale, offset, freqs, s); } -static inline int mlx_fast_rope_dynamic( - mlx_array* res, - const mlx_array x, - int dims, - bool traditional, - mlx_optional_float base, - float scale, - const mlx_array offset, - const mlx_array freqs /* may be null */, - const mlx_stream s) { - return mlx_fast_rope_dynamic_(res, x, dims, traditional, base, scale, offset, freqs, s); -} + static inline int mlx_fast_scaled_dot_product_attention( mlx_array* res, const mlx_array queries, @@ -4307,6 +4257,7 @@ static inline int mlx_fast_scaled_dot_product_attention( const mlx_stream s) { return mlx_fast_scaled_dot_product_attention_(res, queries, keys, values, scale, mask_mode, mask_arr, sinks, s); } + static inline int mlx_fft_fft( mlx_array* res, const mlx_array a, @@ -4315,6 +4266,7 @@ static inline int mlx_fft_fft( const mlx_stream s) { return mlx_fft_fft_(res, a, n, axis, s); } + static inline int mlx_fft_fft2( mlx_array* res, const mlx_array a, @@ -4325,6 +4277,7 @@ static inline int mlx_fft_fft2( const mlx_stream s) { return mlx_fft_fft2_(res, a, n, n_num, axes, axes_num, s); } + static inline int mlx_fft_fftn( mlx_array* res, const mlx_array a, @@ -4335,6 +4288,7 @@ static inline int mlx_fft_fftn( const mlx_stream s) { return mlx_fft_fftn_(res, a, n, n_num, axes, axes_num, s); } + static inline int mlx_fft_fftshift( mlx_array* res, const mlx_array a, @@ -4343,6 +4297,7 @@ static inline int mlx_fft_fftshift( const mlx_stream s) { return mlx_fft_fftshift_(res, a, axes, axes_num, s); } + static inline int mlx_fft_ifft( mlx_array* res, const mlx_array a, @@ -4351,6 +4306,7 @@ static inline int mlx_fft_ifft( const mlx_stream s) { return mlx_fft_ifft_(res, a, n, axis, s); } + static inline int mlx_fft_ifft2( mlx_array* res, const mlx_array a, @@ -4361,6 +4317,7 @@ static inline int mlx_fft_ifft2( const mlx_stream s) { return mlx_fft_ifft2_(res, a, n, n_num, axes, axes_num, s); } + static inline int mlx_fft_ifftn( mlx_array* res, const mlx_array a, @@ -4371,6 +4328,7 @@ static inline int mlx_fft_ifftn( const mlx_stream s) { return mlx_fft_ifftn_(res, a, n, n_num, axes, axes_num, s); } + static inline int mlx_fft_ifftshift( mlx_array* res, const mlx_array a, @@ -4379,6 +4337,7 @@ static inline int mlx_fft_ifftshift( const mlx_stream s) { return mlx_fft_ifftshift_(res, a, axes, axes_num, s); } + static inline int mlx_fft_irfft( mlx_array* res, const mlx_array a, @@ -4387,6 +4346,7 @@ static inline int mlx_fft_irfft( const mlx_stream s) { return mlx_fft_irfft_(res, a, n, axis, s); } + static inline int mlx_fft_irfft2( mlx_array* res, const mlx_array a, @@ -4397,6 +4357,7 @@ static inline int mlx_fft_irfft2( const mlx_stream s) { return mlx_fft_irfft2_(res, a, n, n_num, axes, axes_num, s); } + static inline int mlx_fft_irfftn( mlx_array* res, const mlx_array a, @@ -4407,6 +4368,7 @@ static inline int mlx_fft_irfftn( const mlx_stream s) { return mlx_fft_irfftn_(res, a, n, n_num, axes, axes_num, s); } + static inline int mlx_fft_rfft( mlx_array* res, const mlx_array a, @@ -4415,6 +4377,7 @@ static inline int mlx_fft_rfft( const mlx_stream s) { return mlx_fft_rfft_(res, a, n, axis, s); } + static inline int mlx_fft_rfft2( mlx_array* res, const mlx_array a, @@ -4425,6 +4388,7 @@ static inline int mlx_fft_rfft2( const mlx_stream s) { return mlx_fft_rfft2_(res, a, n, n_num, axes, axes_num, s); } + static inline int mlx_fft_rfftn( mlx_array* res, const mlx_array a, @@ -4435,15 +4399,50 @@ static inline int mlx_fft_rfftn( const mlx_stream s) { return mlx_fft_rfftn_(res, a, n, n_num, axes, axes_num, s); } + +static inline mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable) { + return mlx_io_reader_new_(desc, vtable); +} + +static inline int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io) { + return mlx_io_reader_descriptor_(desc_, io); +} + +static inline int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io) { + return mlx_io_reader_tostring_(str_, io); +} + +static inline int mlx_io_reader_free(mlx_io_reader io) { + return mlx_io_reader_free_(io); +} + +static inline mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable) { + return mlx_io_writer_new_(desc, vtable); +} + +static inline int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io) { + return mlx_io_writer_descriptor_(desc_, io); +} + +static inline int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io) { + return mlx_io_writer_tostring_(str_, io); +} + +static inline int mlx_io_writer_free(mlx_io_writer io) { + return mlx_io_writer_free_(io); +} + static inline int mlx_load_reader( mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) { return mlx_load_reader_(res, in_stream, s); } + static inline int mlx_load(mlx_array* res, const char* file, const mlx_stream s) { return mlx_load_(res, file, s); } + static inline int mlx_load_safetensors_reader( mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, @@ -4451,6 +4450,7 @@ static inline int mlx_load_safetensors_reader( const mlx_stream s) { return mlx_load_safetensors_reader_(res_0, res_1, in_stream, s); } + static inline int mlx_load_safetensors( mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, @@ -4458,48 +4458,29 @@ static inline int mlx_load_safetensors( const mlx_stream s) { return mlx_load_safetensors_(res_0, res_1, file, s); } + static inline int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a) { return mlx_save_writer_(out_stream, a); } + static inline int mlx_save(const char* file, const mlx_array a) { return mlx_save_(file, a); } + static inline int mlx_save_safetensors_writer( mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) { return mlx_save_safetensors_writer_(in_stream, param, metadata); } + static inline int mlx_save_safetensors( const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) { return mlx_save_safetensors_(file, param, metadata); } -static inline mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable) { - return mlx_io_reader_new_(desc, vtable); -} -static inline int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io) { - return mlx_io_reader_descriptor_(desc_, io); -} -static inline int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io) { - return mlx_io_reader_tostring_(str_, io); -} -static inline int mlx_io_reader_free(mlx_io_reader io) { - return mlx_io_reader_free_(io); -} -static inline mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable) { - return mlx_io_writer_new_(desc, vtable); -} -static inline int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io) { - return mlx_io_writer_descriptor_(desc_, io); -} -static inline int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io) { - return mlx_io_writer_tostring_(str_, io); -} -static inline int mlx_io_writer_free(mlx_io_writer io) { - return mlx_io_writer_free_(io); -} + static inline int mlx_linalg_cholesky( mlx_array* res, const mlx_array a, @@ -4507,6 +4488,7 @@ static inline int mlx_linalg_cholesky( const mlx_stream s) { return mlx_linalg_cholesky_(res, a, upper, s); } + static inline int mlx_linalg_cholesky_inv( mlx_array* res, const mlx_array a, @@ -4514,6 +4496,7 @@ static inline int mlx_linalg_cholesky_inv( const mlx_stream s) { return mlx_linalg_cholesky_inv_(res, a, upper, s); } + static inline int mlx_linalg_cross( mlx_array* res, const mlx_array a, @@ -4522,6 +4505,7 @@ static inline int mlx_linalg_cross( const mlx_stream s) { return mlx_linalg_cross_(res, a, b, axis, s); } + static inline int mlx_linalg_eig( mlx_array* res_0, mlx_array* res_1, @@ -4529,6 +4513,7 @@ static inline int mlx_linalg_eig( const mlx_stream s) { return mlx_linalg_eig_(res_0, res_1, a, s); } + static inline int mlx_linalg_eigh( mlx_array* res_0, mlx_array* res_1, @@ -4537,9 +4522,11 @@ static inline int mlx_linalg_eigh( const mlx_stream s) { return mlx_linalg_eigh_(res_0, res_1, a, UPLO, s); } + static inline int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_eigvals_(res, a, s); } + static inline int mlx_linalg_eigvalsh( mlx_array* res, const mlx_array a, @@ -4547,12 +4534,15 @@ static inline int mlx_linalg_eigvalsh( const mlx_stream s) { return mlx_linalg_eigvalsh_(res, a, UPLO, s); } + static inline int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_inv_(res, a, s); } + static inline int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_lu_(res, a, s); } + static inline int mlx_linalg_lu_factor( mlx_array* res_0, mlx_array* res_1, @@ -4560,6 +4550,7 @@ static inline int mlx_linalg_lu_factor( const mlx_stream s) { return mlx_linalg_lu_factor_(res_0, res_1, a, s); } + static inline int mlx_linalg_norm( mlx_array* res, const mlx_array a, @@ -4570,6 +4561,7 @@ static inline int mlx_linalg_norm( const mlx_stream s) { return mlx_linalg_norm_(res, a, ord, axis, axis_num, keepdims, s); } + static inline int mlx_linalg_norm_matrix( mlx_array* res, const mlx_array a, @@ -4580,6 +4572,7 @@ static inline int mlx_linalg_norm_matrix( const mlx_stream s) { return mlx_linalg_norm_matrix_(res, a, ord, axis, axis_num, keepdims, s); } + static inline int mlx_linalg_norm_l2( mlx_array* res, const mlx_array a, @@ -4589,9 +4582,11 @@ static inline int mlx_linalg_norm_l2( const mlx_stream s) { return mlx_linalg_norm_l2_(res, a, axis, axis_num, keepdims, s); } + static inline int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_pinv_(res, a, s); } + static inline int mlx_linalg_qr( mlx_array* res_0, mlx_array* res_1, @@ -4599,6 +4594,7 @@ static inline int mlx_linalg_qr( const mlx_stream s) { return mlx_linalg_qr_(res_0, res_1, a, s); } + static inline int mlx_linalg_solve( mlx_array* res, const mlx_array a, @@ -4606,6 +4602,7 @@ static inline int mlx_linalg_solve( const mlx_stream s) { return mlx_linalg_solve_(res, a, b, s); } + static inline int mlx_linalg_solve_triangular( mlx_array* res, const mlx_array a, @@ -4614,6 +4611,7 @@ static inline int mlx_linalg_solve_triangular( const mlx_stream s) { return mlx_linalg_solve_triangular_(res, a, b, upper, s); } + static inline int mlx_linalg_svd( mlx_vector_array* res, const mlx_array a, @@ -4621,6 +4619,7 @@ static inline int mlx_linalg_svd( const mlx_stream s) { return mlx_linalg_svd_(res, a, compute_uv, s); } + static inline int mlx_linalg_tri_inv( mlx_array* res, const mlx_array a, @@ -4628,118 +4627,152 @@ static inline int mlx_linalg_tri_inv( const mlx_stream s) { return mlx_linalg_tri_inv_(res, a, upper, s); } + static inline mlx_map_string_to_array mlx_map_string_to_array_new(void) { return mlx_map_string_to_array_new_(); } + static inline int mlx_map_string_to_array_set( mlx_map_string_to_array* map, const mlx_map_string_to_array src) { return mlx_map_string_to_array_set_(map, src); } + static inline int mlx_map_string_to_array_free(mlx_map_string_to_array map) { return mlx_map_string_to_array_free_(map); } + static inline int mlx_map_string_to_array_insert( mlx_map_string_to_array map, const char* key, const mlx_array value) { return mlx_map_string_to_array_insert_(map, key, value); } + static inline int mlx_map_string_to_array_get( mlx_array* value, const mlx_map_string_to_array map, const char* key) { return mlx_map_string_to_array_get_(value, map, key); } + static inline mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new( mlx_map_string_to_array map) { return mlx_map_string_to_array_iterator_new_(map); } + static inline int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it) { return mlx_map_string_to_array_iterator_free_(it); } + static inline int mlx_map_string_to_array_iterator_next( const char** key, mlx_array* value, mlx_map_string_to_array_iterator it) { return mlx_map_string_to_array_iterator_next_(key, value, it); } + static inline mlx_map_string_to_string mlx_map_string_to_string_new(void) { return mlx_map_string_to_string_new_(); } + static inline int mlx_map_string_to_string_set( mlx_map_string_to_string* map, const mlx_map_string_to_string src) { return mlx_map_string_to_string_set_(map, src); } + static inline int mlx_map_string_to_string_free(mlx_map_string_to_string map) { return mlx_map_string_to_string_free_(map); } + static inline int mlx_map_string_to_string_insert( mlx_map_string_to_string map, const char* key, const char* value) { return mlx_map_string_to_string_insert_(map, key, value); } + static inline int mlx_map_string_to_string_get( const char** value, const mlx_map_string_to_string map, const char* key) { return mlx_map_string_to_string_get_(value, map, key); } + static inline mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new( mlx_map_string_to_string map) { return mlx_map_string_to_string_iterator_new_(map); } + static inline int mlx_map_string_to_string_iterator_free( mlx_map_string_to_string_iterator it) { return mlx_map_string_to_string_iterator_free_(it); } + static inline int mlx_map_string_to_string_iterator_next( const char** key, const char** value, mlx_map_string_to_string_iterator it) { return mlx_map_string_to_string_iterator_next_(key, value, it); } + static inline int mlx_clear_cache(void) { return mlx_clear_cache_(); } + static inline int mlx_get_active_memory(size_t* res) { return mlx_get_active_memory_(res); } + static inline int mlx_get_cache_memory(size_t* res) { return mlx_get_cache_memory_(res); } + static inline int mlx_get_memory_limit(size_t* res) { return mlx_get_memory_limit_(res); } + static inline int mlx_get_peak_memory(size_t* res) { return mlx_get_peak_memory_(res); } + static inline int mlx_reset_peak_memory(void) { return mlx_reset_peak_memory_(); } + static inline int mlx_set_cache_limit(size_t* res, size_t limit) { return mlx_set_cache_limit_(res, limit); } + static inline int mlx_set_memory_limit(size_t* res, size_t limit) { return mlx_set_memory_limit_(res, limit); } + static inline int mlx_set_wired_limit(size_t* res, size_t limit) { return mlx_set_wired_limit_(res, limit); } + +static inline mlx_metal_device_info_t mlx_metal_device_info(void) { + return mlx_metal_device_info_(); +} + static inline int mlx_metal_is_available(bool* res) { return mlx_metal_is_available_(res); } + static inline int mlx_metal_start_capture(const char* path) { return mlx_metal_start_capture_(path); } + static inline int mlx_metal_stop_capture(void) { return mlx_metal_stop_capture_(); } + static inline int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_abs_(res, a, s); } + static inline int mlx_add( mlx_array* res, const mlx_array a, @@ -4747,6 +4780,7 @@ static inline int mlx_add( const mlx_stream s) { return mlx_add_(res, a, b, s); } + static inline int mlx_addmm( mlx_array* res, const mlx_array c, @@ -4757,6 +4791,7 @@ static inline int mlx_addmm( const mlx_stream s) { return mlx_addmm_(res, c, a, b, alpha, beta, s); } + static inline int mlx_all_axes( mlx_array* res, const mlx_array a, @@ -4766,6 +4801,7 @@ static inline int mlx_all_axes( const mlx_stream s) { return mlx_all_axes_(res, a, axes, axes_num, keepdims, s); } + static inline int mlx_all_axis( mlx_array* res, const mlx_array a, @@ -4774,6 +4810,7 @@ static inline int mlx_all_axis( const mlx_stream s) { return mlx_all_axis_(res, a, axis, keepdims, s); } + static inline int mlx_all( mlx_array* res, const mlx_array a, @@ -4781,6 +4818,7 @@ static inline int mlx_all( const mlx_stream s) { return mlx_all_(res, a, keepdims, s); } + static inline int mlx_allclose( mlx_array* res, const mlx_array a, @@ -4791,6 +4829,7 @@ static inline int mlx_allclose( const mlx_stream s) { return mlx_allclose_(res, a, b, rtol, atol, equal_nan, s); } + static inline int mlx_any_axes( mlx_array* res, const mlx_array a, @@ -4800,6 +4839,7 @@ static inline int mlx_any_axes( const mlx_stream s) { return mlx_any_axes_(res, a, axes, axes_num, keepdims, s); } + static inline int mlx_any_axis( mlx_array* res, const mlx_array a, @@ -4808,6 +4848,7 @@ static inline int mlx_any_axis( const mlx_stream s) { return mlx_any_axis_(res, a, axis, keepdims, s); } + static inline int mlx_any( mlx_array* res, const mlx_array a, @@ -4815,6 +4856,7 @@ static inline int mlx_any( const mlx_stream s) { return mlx_any_(res, a, keepdims, s); } + static inline int mlx_arange( mlx_array* res, double start, @@ -4824,21 +4866,27 @@ static inline int mlx_arange( const mlx_stream s) { return mlx_arange_(res, start, stop, step, dtype, s); } + static inline int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arccos_(res, a, s); } + static inline int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arccosh_(res, a, s); } + static inline int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arcsin_(res, a, s); } + static inline int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arcsinh_(res, a, s); } + static inline int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arctan_(res, a, s); } + static inline int mlx_arctan2( mlx_array* res, const mlx_array a, @@ -4846,9 +4894,11 @@ static inline int mlx_arctan2( const mlx_stream s) { return mlx_arctan2_(res, a, b, s); } + static inline int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arctanh_(res, a, s); } + static inline int mlx_argmax_axis( mlx_array* res, const mlx_array a, @@ -4857,6 +4907,7 @@ static inline int mlx_argmax_axis( const mlx_stream s) { return mlx_argmax_axis_(res, a, axis, keepdims, s); } + static inline int mlx_argmax( mlx_array* res, const mlx_array a, @@ -4864,6 +4915,7 @@ static inline int mlx_argmax( const mlx_stream s) { return mlx_argmax_(res, a, keepdims, s); } + static inline int mlx_argmin_axis( mlx_array* res, const mlx_array a, @@ -4872,6 +4924,7 @@ static inline int mlx_argmin_axis( const mlx_stream s) { return mlx_argmin_axis_(res, a, axis, keepdims, s); } + static inline int mlx_argmin( mlx_array* res, const mlx_array a, @@ -4879,6 +4932,7 @@ static inline int mlx_argmin( const mlx_stream s) { return mlx_argmin_(res, a, keepdims, s); } + static inline int mlx_argpartition_axis( mlx_array* res, const mlx_array a, @@ -4887,6 +4941,7 @@ static inline int mlx_argpartition_axis( const mlx_stream s) { return mlx_argpartition_axis_(res, a, kth, axis, s); } + static inline int mlx_argpartition( mlx_array* res, const mlx_array a, @@ -4894,6 +4949,7 @@ static inline int mlx_argpartition( const mlx_stream s) { return mlx_argpartition_(res, a, kth, s); } + static inline int mlx_argsort_axis( mlx_array* res, const mlx_array a, @@ -4901,9 +4957,11 @@ static inline int mlx_argsort_axis( const mlx_stream s) { return mlx_argsort_axis_(res, a, axis, s); } + static inline int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_argsort_(res, a, s); } + static inline int mlx_array_equal( mlx_array* res, const mlx_array a, @@ -4912,6 +4970,7 @@ static inline int mlx_array_equal( const mlx_stream s) { return mlx_array_equal_(res, a, b, equal_nan, s); } + static inline int mlx_as_strided( mlx_array* res, const mlx_array a, @@ -4923,6 +4982,7 @@ static inline int mlx_as_strided( const mlx_stream s) { return mlx_as_strided_(res, a, shape, shape_num, strides, strides_num, offset, s); } + static inline int mlx_astype( mlx_array* res, const mlx_array a, @@ -4930,15 +4990,19 @@ static inline int mlx_astype( const mlx_stream s) { return mlx_astype_(res, a, dtype, s); } + static inline int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_1d_(res, a, s); } + static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_2d_(res, a, s); } + static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_3d_(res, a, s); } + static inline int mlx_bitwise_and( mlx_array* res, const mlx_array a, @@ -4946,9 +5010,11 @@ static inline int mlx_bitwise_and( const mlx_stream s) { return mlx_bitwise_and_(res, a, b, s); } + static inline int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_bitwise_invert_(res, a, s); } + static inline int mlx_bitwise_or( mlx_array* res, const mlx_array a, @@ -4956,6 +5022,7 @@ static inline int mlx_bitwise_or( const mlx_stream s) { return mlx_bitwise_or_(res, a, b, s); } + static inline int mlx_bitwise_xor( mlx_array* res, const mlx_array a, @@ -4963,6 +5030,7 @@ static inline int mlx_bitwise_xor( const mlx_stream s) { return mlx_bitwise_xor_(res, a, b, s); } + static inline int mlx_block_masked_mm( mlx_array* res, const mlx_array a, @@ -4974,12 +5042,14 @@ static inline int mlx_block_masked_mm( const mlx_stream s) { return mlx_block_masked_mm_(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s); } + static inline int mlx_broadcast_arrays( mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) { return mlx_broadcast_arrays_(res, inputs, s); } + static inline int mlx_broadcast_to( mlx_array* res, const mlx_array a, @@ -4988,9 +5058,11 @@ static inline int mlx_broadcast_to( const mlx_stream s) { return mlx_broadcast_to_(res, a, shape, shape_num, s); } + static inline int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_ceil_(res, a, s); } + static inline int mlx_clip( mlx_array* res, const mlx_array a, @@ -4999,6 +5071,7 @@ static inline int mlx_clip( const mlx_stream s) { return mlx_clip_(res, a, a_min, a_max, s); } + static inline int mlx_concatenate_axis( mlx_array* res, const mlx_vector_array arrays, @@ -5006,15 +5079,18 @@ static inline int mlx_concatenate_axis( const mlx_stream s) { return mlx_concatenate_axis_(res, arrays, axis, s); } + static inline int mlx_concatenate( mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) { return mlx_concatenate_(res, arrays, s); } + static inline int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_conjugate_(res, a, s); } + static inline int mlx_contiguous( mlx_array* res, const mlx_array a, @@ -5022,6 +5098,7 @@ static inline int mlx_contiguous( const mlx_stream s) { return mlx_contiguous_(res, a, allow_col_major, s); } + static inline int mlx_conv1d( mlx_array* res, const mlx_array input, @@ -5033,6 +5110,7 @@ static inline int mlx_conv1d( const mlx_stream s) { return mlx_conv1d_(res, input, weight, stride, padding, dilation, groups, s); } + static inline int mlx_conv2d( mlx_array* res, const mlx_array input, @@ -5047,6 +5125,7 @@ static inline int mlx_conv2d( const mlx_stream s) { return mlx_conv2d_(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, groups, s); } + static inline int mlx_conv3d( mlx_array* res, const mlx_array input, @@ -5064,6 +5143,7 @@ static inline int mlx_conv3d( const mlx_stream s) { return mlx_conv3d_(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, groups, s); } + static inline int mlx_conv_general( mlx_array* res, const mlx_array input, @@ -5083,6 +5163,7 @@ static inline int mlx_conv_general( const mlx_stream s) { return mlx_conv_general_(res, input, weight, stride, stride_num, padding_lo, padding_lo_num, padding_hi, padding_hi_num, kernel_dilation, kernel_dilation_num, input_dilation, input_dilation_num, groups, flip, s); } + static inline int mlx_conv_transpose1d( mlx_array* res, const mlx_array input, @@ -5095,6 +5176,7 @@ static inline int mlx_conv_transpose1d( const mlx_stream s) { return mlx_conv_transpose1d_(res, input, weight, stride, padding, dilation, output_padding, groups, s); } + static inline int mlx_conv_transpose2d( mlx_array* res, const mlx_array input, @@ -5111,6 +5193,7 @@ static inline int mlx_conv_transpose2d( const mlx_stream s) { return mlx_conv_transpose2d_(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, output_padding_0, output_padding_1, groups, s); } + static inline int mlx_conv_transpose3d( mlx_array* res, const mlx_array input, @@ -5131,15 +5214,19 @@ static inline int mlx_conv_transpose3d( const mlx_stream s) { return mlx_conv_transpose3d_(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, output_padding_0, output_padding_1, output_padding_2, groups, s); } + static inline int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_copy_(res, a, s); } + static inline int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_cos_(res, a, s); } + static inline int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_cosh_(res, a, s); } + static inline int mlx_cummax( mlx_array* res, const mlx_array a, @@ -5149,6 +5236,7 @@ static inline int mlx_cummax( const mlx_stream s) { return mlx_cummax_(res, a, axis, reverse, inclusive, s); } + static inline int mlx_cummin( mlx_array* res, const mlx_array a, @@ -5158,6 +5246,7 @@ static inline int mlx_cummin( const mlx_stream s) { return mlx_cummin_(res, a, axis, reverse, inclusive, s); } + static inline int mlx_cumprod( mlx_array* res, const mlx_array a, @@ -5167,6 +5256,7 @@ static inline int mlx_cumprod( const mlx_stream s) { return mlx_cumprod_(res, a, axis, reverse, inclusive, s); } + static inline int mlx_cumsum( mlx_array* res, const mlx_array a, @@ -5176,15 +5266,18 @@ static inline int mlx_cumsum( const mlx_stream s) { return mlx_cumsum_(res, a, axis, reverse, inclusive, s); } + static inline int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_degrees_(res, a, s); } + static inline int mlx_depends( mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) { return mlx_depends_(res, inputs, dependencies); } + static inline int mlx_dequantize( mlx_array* res, const mlx_array w, @@ -5197,9 +5290,11 @@ static inline int mlx_dequantize( const mlx_stream s) { return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s); } + static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { return mlx_diag_(res, a, k, s); } + static inline int mlx_diagonal( mlx_array* res, const mlx_array a, @@ -5209,6 +5304,7 @@ static inline int mlx_diagonal( const mlx_stream s) { return mlx_diagonal_(res, a, offset, axis1, axis2, s); } + static inline int mlx_divide( mlx_array* res, const mlx_array a, @@ -5216,6 +5312,7 @@ static inline int mlx_divide( const mlx_stream s) { return mlx_divide_(res, a, b, s); } + static inline int mlx_divmod( mlx_vector_array* res, const mlx_array a, @@ -5223,6 +5320,7 @@ static inline int mlx_divmod( const mlx_stream s) { return mlx_divmod_(res, a, b, s); } + static inline int mlx_einsum( mlx_array* res, const char* subscripts, @@ -5230,6 +5328,7 @@ static inline int mlx_einsum( const mlx_stream s) { return mlx_einsum_(res, subscripts, operands, s); } + static inline int mlx_equal( mlx_array* res, const mlx_array a, @@ -5237,15 +5336,19 @@ static inline int mlx_equal( const mlx_stream s) { return mlx_equal_(res, a, b, s); } + static inline int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_erf_(res, a, s); } + static inline int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_erfinv_(res, a, s); } + static inline int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_exp_(res, a, s); } + static inline int mlx_expand_dims_axes( mlx_array* res, const mlx_array a, @@ -5254,6 +5357,7 @@ static inline int mlx_expand_dims_axes( const mlx_stream s) { return mlx_expand_dims_axes_(res, a, axes, axes_num, s); } + static inline int mlx_expand_dims( mlx_array* res, const mlx_array a, @@ -5261,9 +5365,11 @@ static inline int mlx_expand_dims( const mlx_stream s) { return mlx_expand_dims_(res, a, axis, s); } + static inline int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_expm1_(res, a, s); } + static inline int mlx_eye( mlx_array* res, int n, @@ -5273,6 +5379,7 @@ static inline int mlx_eye( const mlx_stream s) { return mlx_eye_(res, n, m, k, dtype, s); } + static inline int mlx_flatten( mlx_array* res, const mlx_array a, @@ -5281,9 +5388,11 @@ static inline int mlx_flatten( const mlx_stream s) { return mlx_flatten_(res, a, start_axis, end_axis, s); } + static inline int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_floor_(res, a, s); } + static inline int mlx_floor_divide( mlx_array* res, const mlx_array a, @@ -5291,6 +5400,7 @@ static inline int mlx_floor_divide( const mlx_stream s) { return mlx_floor_divide_(res, a, b, s); } + static inline int mlx_from_fp8( mlx_array* res, const mlx_array x, @@ -5298,6 +5408,7 @@ static inline int mlx_from_fp8( const mlx_stream s) { return mlx_from_fp8_(res, x, dtype, s); } + static inline int mlx_full( mlx_array* res, const int* shape, @@ -5307,6 +5418,7 @@ static inline int mlx_full( const mlx_stream s) { return mlx_full_(res, shape, shape_num, vals, dtype, s); } + static inline int mlx_full_like( mlx_array* res, const mlx_array a, @@ -5315,6 +5427,7 @@ static inline int mlx_full_like( const mlx_stream s) { return mlx_full_like_(res, a, vals, dtype, s); } + static inline int mlx_gather( mlx_array* res, const mlx_array a, @@ -5326,16 +5439,7 @@ static inline int mlx_gather( const mlx_stream s) { return mlx_gather_(res, a, indices, axes, axes_num, slice_sizes, slice_sizes_num, s); } -static inline int mlx_gather_single( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - int axis, - const int* slice_sizes, - size_t slice_sizes_num, - const mlx_stream s) { - return mlx_gather_single_(res, a, indices, axis, slice_sizes, slice_sizes_num, s); -} + static inline int mlx_gather_mm( mlx_array* res, const mlx_array a, @@ -5346,6 +5450,7 @@ static inline int mlx_gather_mm( const mlx_stream s) { return mlx_gather_mm_(res, a, b, lhs_indices, rhs_indices, sorted_indices, s); } + static inline int mlx_gather_qmm( mlx_array* res, const mlx_array x, @@ -5362,6 +5467,7 @@ static inline int mlx_gather_qmm( const mlx_stream s) { return mlx_gather_qmm_(res, x, w, scales, biases, lhs_indices, rhs_indices, transpose, group_size, bits, mode, sorted_indices, s); } + static inline int mlx_greater( mlx_array* res, const mlx_array a, @@ -5369,6 +5475,7 @@ static inline int mlx_greater( const mlx_stream s) { return mlx_greater_(res, a, b, s); } + static inline int mlx_greater_equal( mlx_array* res, const mlx_array a, @@ -5376,6 +5483,7 @@ static inline int mlx_greater_equal( const mlx_stream s) { return mlx_greater_equal_(res, a, b, s); } + static inline int mlx_hadamard_transform( mlx_array* res, const mlx_array a, @@ -5383,12 +5491,15 @@ static inline int mlx_hadamard_transform( const mlx_stream s) { return mlx_hadamard_transform_(res, a, scale, s); } + static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) { return mlx_identity_(res, n, dtype, s); } + static inline int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_imag_(res, a, s); } + static inline int mlx_inner( mlx_array* res, const mlx_array a, @@ -5396,6 +5507,7 @@ static inline int mlx_inner( const mlx_stream s) { return mlx_inner_(res, a, b, s); } + static inline int mlx_isclose( mlx_array* res, const mlx_array a, @@ -5406,21 +5518,27 @@ static inline int mlx_isclose( const mlx_stream s) { return mlx_isclose_(res, a, b, rtol, atol, equal_nan, s); } + static inline int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isfinite_(res, a, s); } + static inline int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isinf_(res, a, s); } + static inline int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isnan_(res, a, s); } + static inline int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isneginf_(res, a, s); } + static inline int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isposinf_(res, a, s); } + static inline int mlx_kron( mlx_array* res, const mlx_array a, @@ -5428,6 +5546,7 @@ static inline int mlx_kron( const mlx_stream s) { return mlx_kron_(res, a, b, s); } + static inline int mlx_left_shift( mlx_array* res, const mlx_array a, @@ -5435,6 +5554,7 @@ static inline int mlx_left_shift( const mlx_stream s) { return mlx_left_shift_(res, a, b, s); } + static inline int mlx_less( mlx_array* res, const mlx_array a, @@ -5442,6 +5562,7 @@ static inline int mlx_less( const mlx_stream s) { return mlx_less_(res, a, b, s); } + static inline int mlx_less_equal( mlx_array* res, const mlx_array a, @@ -5449,6 +5570,7 @@ static inline int mlx_less_equal( const mlx_stream s) { return mlx_less_equal_(res, a, b, s); } + static inline int mlx_linspace( mlx_array* res, double start, @@ -5458,18 +5580,23 @@ static inline int mlx_linspace( const mlx_stream s) { return mlx_linspace_(res, start, stop, num, dtype, s); } + static inline int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log_(res, a, s); } + static inline int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log10_(res, a, s); } + static inline int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log1p_(res, a, s); } + static inline int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log2_(res, a, s); } + static inline int mlx_logaddexp( mlx_array* res, const mlx_array a, @@ -5477,6 +5604,7 @@ static inline int mlx_logaddexp( const mlx_stream s) { return mlx_logaddexp_(res, a, b, s); } + static inline int mlx_logcumsumexp( mlx_array* res, const mlx_array a, @@ -5486,6 +5614,7 @@ static inline int mlx_logcumsumexp( const mlx_stream s) { return mlx_logcumsumexp_(res, a, axis, reverse, inclusive, s); } + static inline int mlx_logical_and( mlx_array* res, const mlx_array a, @@ -5493,9 +5622,11 @@ static inline int mlx_logical_and( const mlx_stream s) { return mlx_logical_and_(res, a, b, s); } + static inline int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_logical_not_(res, a, s); } + static inline int mlx_logical_or( mlx_array* res, const mlx_array a, @@ -5503,6 +5634,7 @@ static inline int mlx_logical_or( const mlx_stream s) { return mlx_logical_or_(res, a, b, s); } + static inline int mlx_logsumexp_axes( mlx_array* res, const mlx_array a, @@ -5512,6 +5644,7 @@ static inline int mlx_logsumexp_axes( const mlx_stream s) { return mlx_logsumexp_axes_(res, a, axes, axes_num, keepdims, s); } + static inline int mlx_logsumexp_axis( mlx_array* res, const mlx_array a, @@ -5520,6 +5653,7 @@ static inline int mlx_logsumexp_axis( const mlx_stream s) { return mlx_logsumexp_axis_(res, a, axis, keepdims, s); } + static inline int mlx_logsumexp( mlx_array* res, const mlx_array a, @@ -5527,6 +5661,7 @@ static inline int mlx_logsumexp( const mlx_stream s) { return mlx_logsumexp_(res, a, keepdims, s); } + static inline int mlx_masked_scatter( mlx_array* res, const mlx_array a, @@ -5535,6 +5670,7 @@ static inline int mlx_masked_scatter( const mlx_stream s) { return mlx_masked_scatter_(res, a, mask, src, s); } + static inline int mlx_matmul( mlx_array* res, const mlx_array a, @@ -5542,6 +5678,7 @@ static inline int mlx_matmul( const mlx_stream s) { return mlx_matmul_(res, a, b, s); } + static inline int mlx_max_axes( mlx_array* res, const mlx_array a, @@ -5551,6 +5688,7 @@ static inline int mlx_max_axes( const mlx_stream s) { return mlx_max_axes_(res, a, axes, axes_num, keepdims, s); } + static inline int mlx_max_axis( mlx_array* res, const mlx_array a, @@ -5559,6 +5697,7 @@ static inline int mlx_max_axis( const mlx_stream s) { return mlx_max_axis_(res, a, axis, keepdims, s); } + static inline int mlx_max( mlx_array* res, const mlx_array a, @@ -5566,6 +5705,7 @@ static inline int mlx_max( const mlx_stream s) { return mlx_max_(res, a, keepdims, s); } + static inline int mlx_maximum( mlx_array* res, const mlx_array a, @@ -5573,6 +5713,7 @@ static inline int mlx_maximum( const mlx_stream s) { return mlx_maximum_(res, a, b, s); } + static inline int mlx_mean_axes( mlx_array* res, const mlx_array a, @@ -5582,6 +5723,7 @@ static inline int mlx_mean_axes( const mlx_stream s) { return mlx_mean_axes_(res, a, axes, axes_num, keepdims, s); } + static inline int mlx_mean_axis( mlx_array* res, const mlx_array a, @@ -5590,6 +5732,7 @@ static inline int mlx_mean_axis( const mlx_stream s) { return mlx_mean_axis_(res, a, axis, keepdims, s); } + static inline int mlx_mean( mlx_array* res, const mlx_array a, @@ -5597,6 +5740,7 @@ static inline int mlx_mean( const mlx_stream s) { return mlx_mean_(res, a, keepdims, s); } + static inline int mlx_median( mlx_array* res, const mlx_array a, @@ -5606,6 +5750,7 @@ static inline int mlx_median( const mlx_stream s) { return mlx_median_(res, a, axes, axes_num, keepdims, s); } + static inline int mlx_meshgrid( mlx_vector_array* res, const mlx_vector_array arrays, @@ -5614,6 +5759,7 @@ static inline int mlx_meshgrid( const mlx_stream s) { return mlx_meshgrid_(res, arrays, sparse, indexing, s); } + static inline int mlx_min_axes( mlx_array* res, const mlx_array a, @@ -5623,6 +5769,7 @@ static inline int mlx_min_axes( const mlx_stream s) { return mlx_min_axes_(res, a, axes, axes_num, keepdims, s); } + static inline int mlx_min_axis( mlx_array* res, const mlx_array a, @@ -5631,6 +5778,7 @@ static inline int mlx_min_axis( const mlx_stream s) { return mlx_min_axis_(res, a, axis, keepdims, s); } + static inline int mlx_min( mlx_array* res, const mlx_array a, @@ -5638,6 +5786,7 @@ static inline int mlx_min( const mlx_stream s) { return mlx_min_(res, a, keepdims, s); } + static inline int mlx_minimum( mlx_array* res, const mlx_array a, @@ -5645,6 +5794,7 @@ static inline int mlx_minimum( const mlx_stream s) { return mlx_minimum_(res, a, b, s); } + static inline int mlx_moveaxis( mlx_array* res, const mlx_array a, @@ -5653,6 +5803,7 @@ static inline int mlx_moveaxis( const mlx_stream s) { return mlx_moveaxis_(res, a, source, destination, s); } + static inline int mlx_multiply( mlx_array* res, const mlx_array a, @@ -5660,6 +5811,7 @@ static inline int mlx_multiply( const mlx_stream s) { return mlx_multiply_(res, a, b, s); } + static inline int mlx_nan_to_num( mlx_array* res, const mlx_array a, @@ -5669,9 +5821,11 @@ static inline int mlx_nan_to_num( const mlx_stream s) { return mlx_nan_to_num_(res, a, nan, posinf, neginf, s); } + static inline int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_negative_(res, a, s); } + static inline int mlx_not_equal( mlx_array* res, const mlx_array a, @@ -5679,6 +5833,7 @@ static inline int mlx_not_equal( const mlx_stream s) { return mlx_not_equal_(res, a, b, s); } + static inline int mlx_number_of_elements( mlx_array* res, const mlx_array a, @@ -5689,6 +5844,7 @@ static inline int mlx_number_of_elements( const mlx_stream s) { return mlx_number_of_elements_(res, a, axes, axes_num, inverted, dtype, s); } + static inline int mlx_ones( mlx_array* res, const int* shape, @@ -5697,9 +5853,11 @@ static inline int mlx_ones( const mlx_stream s) { return mlx_ones_(res, shape, shape_num, dtype, s); } + static inline int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_ones_like_(res, a, s); } + static inline int mlx_outer( mlx_array* res, const mlx_array a, @@ -5707,6 +5865,7 @@ static inline int mlx_outer( const mlx_stream s) { return mlx_outer_(res, a, b, s); } + static inline int mlx_pad( mlx_array* res, const mlx_array a, @@ -5721,6 +5880,7 @@ static inline int mlx_pad( const mlx_stream s) { return mlx_pad_(res, a, axes, axes_num, low_pad_size, low_pad_size_num, high_pad_size, high_pad_size_num, pad_value, mode, s); } + static inline int mlx_pad_symmetric( mlx_array* res, const mlx_array a, @@ -5730,6 +5890,7 @@ static inline int mlx_pad_symmetric( const mlx_stream s) { return mlx_pad_symmetric_(res, a, pad_width, pad_value, mode, s); } + static inline int mlx_partition_axis( mlx_array* res, const mlx_array a, @@ -5738,6 +5899,7 @@ static inline int mlx_partition_axis( const mlx_stream s) { return mlx_partition_axis_(res, a, kth, axis, s); } + static inline int mlx_partition( mlx_array* res, const mlx_array a, @@ -5745,6 +5907,7 @@ static inline int mlx_partition( const mlx_stream s) { return mlx_partition_(res, a, kth, s); } + static inline int mlx_power( mlx_array* res, const mlx_array a, @@ -5752,6 +5915,7 @@ static inline int mlx_power( const mlx_stream s) { return mlx_power_(res, a, b, s); } + static inline int mlx_prod_axes( mlx_array* res, const mlx_array a, @@ -5761,6 +5925,7 @@ static inline int mlx_prod_axes( const mlx_stream s) { return mlx_prod_axes_(res, a, axes, axes_num, keepdims, s); } + static inline int mlx_prod_axis( mlx_array* res, const mlx_array a, @@ -5769,6 +5934,7 @@ static inline int mlx_prod_axis( const mlx_stream s) { return mlx_prod_axis_(res, a, axis, keepdims, s); } + static inline int mlx_prod( mlx_array* res, const mlx_array a, @@ -5776,6 +5942,7 @@ static inline int mlx_prod( const mlx_stream s) { return mlx_prod_(res, a, keepdims, s); } + static inline int mlx_put_along_axis( mlx_array* res, const mlx_array a, @@ -5785,17 +5952,7 @@ static inline int mlx_put_along_axis( const mlx_stream s) { return mlx_put_along_axis_(res, a, indices, values, axis, s); } -static inline int mlx_qqmm( - mlx_array* res, - const mlx_array x, - const mlx_array w, - const mlx_array w_scales /* may be null */, - mlx_optional_int group_size, - mlx_optional_int bits, - const char* mode, - const mlx_stream s) { - return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, s); -} + static inline int mlx_quantize( mlx_vector_array* res, const mlx_array w, @@ -5805,6 +5962,7 @@ static inline int mlx_quantize( const mlx_stream s) { return mlx_quantize_(res, w, group_size, bits, mode, s); } + static inline int mlx_quantized_matmul( mlx_array* res, const mlx_array x, @@ -5818,15 +5976,19 @@ static inline int mlx_quantized_matmul( const mlx_stream s) { return mlx_quantized_matmul_(res, x, w, scales, biases, transpose, group_size, bits, mode, s); } + static inline int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_radians_(res, a, s); } + static inline int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_real_(res, a, s); } + static inline int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_reciprocal_(res, a, s); } + static inline int mlx_remainder( mlx_array* res, const mlx_array a, @@ -5834,6 +5996,7 @@ static inline int mlx_remainder( const mlx_stream s) { return mlx_remainder_(res, a, b, s); } + static inline int mlx_repeat_axis( mlx_array* res, const mlx_array arr, @@ -5842,6 +6005,7 @@ static inline int mlx_repeat_axis( const mlx_stream s) { return mlx_repeat_axis_(res, arr, repeats, axis, s); } + static inline int mlx_repeat( mlx_array* res, const mlx_array arr, @@ -5849,6 +6013,7 @@ static inline int mlx_repeat( const mlx_stream s) { return mlx_repeat_(res, arr, repeats, s); } + static inline int mlx_reshape( mlx_array* res, const mlx_array a, @@ -5857,6 +6022,7 @@ static inline int mlx_reshape( const mlx_stream s) { return mlx_reshape_(res, a, shape, shape_num, s); } + static inline int mlx_right_shift( mlx_array* res, const mlx_array a, @@ -5864,6 +6030,7 @@ static inline int mlx_right_shift( const mlx_stream s) { return mlx_right_shift_(res, a, b, s); } + static inline int mlx_roll_axis( mlx_array* res, const mlx_array a, @@ -5873,6 +6040,7 @@ static inline int mlx_roll_axis( const mlx_stream s) { return mlx_roll_axis_(res, a, shift, shift_num, axis, s); } + static inline int mlx_roll_axes( mlx_array* res, const mlx_array a, @@ -5883,6 +6051,7 @@ static inline int mlx_roll_axes( const mlx_stream s) { return mlx_roll_axes_(res, a, shift, shift_num, axes, axes_num, s); } + static inline int mlx_roll( mlx_array* res, const mlx_array a, @@ -5891,6 +6060,7 @@ static inline int mlx_roll( const mlx_stream s) { return mlx_roll_(res, a, shift, shift_num, s); } + static inline int mlx_round( mlx_array* res, const mlx_array a, @@ -5898,9 +6068,11 @@ static inline int mlx_round( const mlx_stream s) { return mlx_round_(res, a, decimals, s); } + static inline int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_rsqrt_(res, a, s); } + static inline int mlx_scatter( mlx_array* res, const mlx_array a, @@ -5911,15 +6083,7 @@ static inline int mlx_scatter( const mlx_stream s) { return mlx_scatter_(res, a, indices, updates, axes, axes_num, s); } -static inline int mlx_scatter_single( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s) { - return mlx_scatter_single_(res, a, indices, updates, axis, s); -} + static inline int mlx_scatter_add( mlx_array* res, const mlx_array a, @@ -5930,15 +6094,7 @@ static inline int mlx_scatter_add( const mlx_stream s) { return mlx_scatter_add_(res, a, indices, updates, axes, axes_num, s); } -static inline int mlx_scatter_add_single( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s) { - return mlx_scatter_add_single_(res, a, indices, updates, axis, s); -} + static inline int mlx_scatter_add_axis( mlx_array* res, const mlx_array a, @@ -5948,6 +6104,7 @@ static inline int mlx_scatter_add_axis( const mlx_stream s) { return mlx_scatter_add_axis_(res, a, indices, values, axis, s); } + static inline int mlx_scatter_max( mlx_array* res, const mlx_array a, @@ -5958,15 +6115,7 @@ static inline int mlx_scatter_max( const mlx_stream s) { return mlx_scatter_max_(res, a, indices, updates, axes, axes_num, s); } -static inline int mlx_scatter_max_single( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s) { - return mlx_scatter_max_single_(res, a, indices, updates, axis, s); -} + static inline int mlx_scatter_min( mlx_array* res, const mlx_array a, @@ -5977,15 +6126,7 @@ static inline int mlx_scatter_min( const mlx_stream s) { return mlx_scatter_min_(res, a, indices, updates, axes, axes_num, s); } -static inline int mlx_scatter_min_single( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s) { - return mlx_scatter_min_single_(res, a, indices, updates, axis, s); -} + static inline int mlx_scatter_prod( mlx_array* res, const mlx_array a, @@ -5996,15 +6137,7 @@ static inline int mlx_scatter_prod( const mlx_stream s) { return mlx_scatter_prod_(res, a, indices, updates, axes, axes_num, s); } -static inline int mlx_scatter_prod_single( - mlx_array* res, - const mlx_array a, - const mlx_array indices, - const mlx_array updates, - int axis, - const mlx_stream s) { - return mlx_scatter_prod_single_(res, a, indices, updates, axis, s); -} + static inline int mlx_segmented_mm( mlx_array* res, const mlx_array a, @@ -6013,18 +6146,23 @@ static inline int mlx_segmented_mm( const mlx_stream s) { return mlx_segmented_mm_(res, a, b, segments, s); } + static inline int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sigmoid_(res, a, s); } + static inline int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sign_(res, a, s); } + static inline int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sin_(res, a, s); } + static inline int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sinh_(res, a, s); } + static inline int mlx_slice( mlx_array* res, const mlx_array a, @@ -6037,6 +6175,7 @@ static inline int mlx_slice( const mlx_stream s) { return mlx_slice_(res, a, start, start_num, stop, stop_num, strides, strides_num, s); } + static inline int mlx_slice_dynamic( mlx_array* res, const mlx_array a, @@ -6048,6 +6187,7 @@ static inline int mlx_slice_dynamic( const mlx_stream s) { return mlx_slice_dynamic_(res, a, start, axes, axes_num, slice_size, slice_size_num, s); } + static inline int mlx_slice_update( mlx_array* res, const mlx_array src, @@ -6061,6 +6201,7 @@ static inline int mlx_slice_update( const mlx_stream s) { return mlx_slice_update_(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); } + static inline int mlx_slice_update_dynamic( mlx_array* res, const mlx_array src, @@ -6071,6 +6212,7 @@ static inline int mlx_slice_update_dynamic( const mlx_stream s) { return mlx_slice_update_dynamic_(res, src, update, start, axes, axes_num, s); } + static inline int mlx_softmax_axes( mlx_array* res, const mlx_array a, @@ -6080,6 +6222,7 @@ static inline int mlx_softmax_axes( const mlx_stream s) { return mlx_softmax_axes_(res, a, axes, axes_num, precise, s); } + static inline int mlx_softmax_axis( mlx_array* res, const mlx_array a, @@ -6088,6 +6231,7 @@ static inline int mlx_softmax_axis( const mlx_stream s) { return mlx_softmax_axis_(res, a, axis, precise, s); } + static inline int mlx_softmax( mlx_array* res, const mlx_array a, @@ -6095,6 +6239,7 @@ static inline int mlx_softmax( const mlx_stream s) { return mlx_softmax_(res, a, precise, s); } + static inline int mlx_sort_axis( mlx_array* res, const mlx_array a, @@ -6102,9 +6247,11 @@ static inline int mlx_sort_axis( const mlx_stream s) { return mlx_sort_axis_(res, a, axis, s); } + static inline int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sort_(res, a, s); } + static inline int mlx_split( mlx_vector_array* res, const mlx_array a, @@ -6113,6 +6260,7 @@ static inline int mlx_split( const mlx_stream s) { return mlx_split_(res, a, num_splits, axis, s); } + static inline int mlx_split_sections( mlx_vector_array* res, const mlx_array a, @@ -6122,12 +6270,15 @@ static inline int mlx_split_sections( const mlx_stream s) { return mlx_split_sections_(res, a, indices, indices_num, axis, s); } + static inline int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sqrt_(res, a, s); } + static inline int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_square_(res, a, s); } + static inline int mlx_squeeze_axes( mlx_array* res, const mlx_array a, @@ -6136,6 +6287,7 @@ static inline int mlx_squeeze_axes( const mlx_stream s) { return mlx_squeeze_axes_(res, a, axes, axes_num, s); } + static inline int mlx_squeeze_axis( mlx_array* res, const mlx_array a, @@ -6143,9 +6295,11 @@ static inline int mlx_squeeze_axis( const mlx_stream s) { return mlx_squeeze_axis_(res, a, axis, s); } + static inline int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_squeeze_(res, a, s); } + static inline int mlx_stack_axis( mlx_array* res, const mlx_vector_array arrays, @@ -6153,12 +6307,14 @@ static inline int mlx_stack_axis( const mlx_stream s) { return mlx_stack_axis_(res, arrays, axis, s); } + static inline int mlx_stack( mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) { return mlx_stack_(res, arrays, s); } + static inline int mlx_std_axes( mlx_array* res, const mlx_array a, @@ -6169,6 +6325,7 @@ static inline int mlx_std_axes( const mlx_stream s) { return mlx_std_axes_(res, a, axes, axes_num, keepdims, ddof, s); } + static inline int mlx_std_axis( mlx_array* res, const mlx_array a, @@ -6178,6 +6335,7 @@ static inline int mlx_std_axis( const mlx_stream s) { return mlx_std_axis_(res, a, axis, keepdims, ddof, s); } + static inline int mlx_std( mlx_array* res, const mlx_array a, @@ -6186,9 +6344,11 @@ static inline int mlx_std( const mlx_stream s) { return mlx_std_(res, a, keepdims, ddof, s); } + static inline int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_stop_gradient_(res, a, s); } + static inline int mlx_subtract( mlx_array* res, const mlx_array a, @@ -6196,6 +6356,7 @@ static inline int mlx_subtract( const mlx_stream s) { return mlx_subtract_(res, a, b, s); } + static inline int mlx_sum_axes( mlx_array* res, const mlx_array a, @@ -6205,6 +6366,7 @@ static inline int mlx_sum_axes( const mlx_stream s) { return mlx_sum_axes_(res, a, axes, axes_num, keepdims, s); } + static inline int mlx_sum_axis( mlx_array* res, const mlx_array a, @@ -6213,6 +6375,7 @@ static inline int mlx_sum_axis( const mlx_stream s) { return mlx_sum_axis_(res, a, axis, keepdims, s); } + static inline int mlx_sum( mlx_array* res, const mlx_array a, @@ -6220,6 +6383,7 @@ static inline int mlx_sum( const mlx_stream s) { return mlx_sum_(res, a, keepdims, s); } + static inline int mlx_swapaxes( mlx_array* res, const mlx_array a, @@ -6228,6 +6392,7 @@ static inline int mlx_swapaxes( const mlx_stream s) { return mlx_swapaxes_(res, a, axis1, axis2, s); } + static inline int mlx_take_axis( mlx_array* res, const mlx_array a, @@ -6236,6 +6401,7 @@ static inline int mlx_take_axis( const mlx_stream s) { return mlx_take_axis_(res, a, indices, axis, s); } + static inline int mlx_take( mlx_array* res, const mlx_array a, @@ -6243,6 +6409,7 @@ static inline int mlx_take( const mlx_stream s) { return mlx_take_(res, a, indices, s); } + static inline int mlx_take_along_axis( mlx_array* res, const mlx_array a, @@ -6251,12 +6418,15 @@ static inline int mlx_take_along_axis( const mlx_stream s) { return mlx_take_along_axis_(res, a, indices, axis, s); } + static inline int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_tan_(res, a, s); } + static inline int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_tanh_(res, a, s); } + static inline int mlx_tensordot( mlx_array* res, const mlx_array a, @@ -6268,6 +6438,7 @@ static inline int mlx_tensordot( const mlx_stream s) { return mlx_tensordot_(res, a, b, axes_a, axes_a_num, axes_b, axes_b_num, s); } + static inline int mlx_tensordot_axis( mlx_array* res, const mlx_array a, @@ -6276,6 +6447,7 @@ static inline int mlx_tensordot_axis( const mlx_stream s) { return mlx_tensordot_axis_(res, a, b, axis, s); } + static inline int mlx_tile( mlx_array* res, const mlx_array arr, @@ -6284,9 +6456,11 @@ static inline int mlx_tile( const mlx_stream s) { return mlx_tile_(res, arr, reps, reps_num, s); } + static inline int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s) { return mlx_to_fp8_(res, x, s); } + static inline int mlx_topk_axis( mlx_array* res, const mlx_array a, @@ -6295,9 +6469,11 @@ static inline int mlx_topk_axis( const mlx_stream s) { return mlx_topk_axis_(res, a, k, axis, s); } + static inline int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { return mlx_topk_(res, a, k, s); } + static inline int mlx_trace( mlx_array* res, const mlx_array a, @@ -6308,6 +6484,7 @@ static inline int mlx_trace( const mlx_stream s) { return mlx_trace_(res, a, offset, axis1, axis2, dtype, s); } + static inline int mlx_transpose_axes( mlx_array* res, const mlx_array a, @@ -6316,9 +6493,11 @@ static inline int mlx_transpose_axes( const mlx_stream s) { return mlx_transpose_axes_(res, a, axes, axes_num, s); } + static inline int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_transpose_(res, a, s); } + static inline int mlx_tri( mlx_array* res, int n, @@ -6328,12 +6507,15 @@ static inline int mlx_tri( const mlx_stream s) { return mlx_tri_(res, n, m, k, type, s); } + static inline int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s) { return mlx_tril_(res, x, k, s); } + static inline int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s) { return mlx_triu_(res, x, k, s); } + static inline int mlx_unflatten( mlx_array* res, const mlx_array a, @@ -6343,6 +6525,7 @@ static inline int mlx_unflatten( const mlx_stream s) { return mlx_unflatten_(res, a, axis, shape, shape_num, s); } + static inline int mlx_var_axes( mlx_array* res, const mlx_array a, @@ -6353,6 +6536,7 @@ static inline int mlx_var_axes( const mlx_stream s) { return mlx_var_axes_(res, a, axes, axes_num, keepdims, ddof, s); } + static inline int mlx_var_axis( mlx_array* res, const mlx_array a, @@ -6362,6 +6546,7 @@ static inline int mlx_var_axis( const mlx_stream s) { return mlx_var_axis_(res, a, axis, keepdims, ddof, s); } + static inline int mlx_var( mlx_array* res, const mlx_array a, @@ -6370,6 +6555,7 @@ static inline int mlx_var( const mlx_stream s) { return mlx_var_(res, a, keepdims, ddof, s); } + static inline int mlx_view( mlx_array* res, const mlx_array a, @@ -6377,6 +6563,7 @@ static inline int mlx_view( const mlx_stream s) { return mlx_view_(res, a, dtype, s); } + static inline int mlx_where( mlx_array* res, const mlx_array condition, @@ -6385,6 +6572,7 @@ static inline int mlx_where( const mlx_stream s) { return mlx_where_(res, condition, x, y, s); } + static inline int mlx_zeros( mlx_array* res, const int* shape, @@ -6393,9 +6581,11 @@ static inline int mlx_zeros( const mlx_stream s) { return mlx_zeros_(res, shape, shape_num, dtype, s); } + static inline int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_zeros_like_(res, a, s); } + static inline int mlx_random_bernoulli( mlx_array* res, const mlx_array p, @@ -6405,6 +6595,7 @@ static inline int mlx_random_bernoulli( const mlx_stream s) { return mlx_random_bernoulli_(res, p, shape, shape_num, key, s); } + static inline int mlx_random_bits( mlx_array* res, const int* shape, @@ -6414,6 +6605,7 @@ static inline int mlx_random_bits( const mlx_stream s) { return mlx_random_bits_(res, shape, shape_num, width, key, s); } + static inline int mlx_random_categorical_shape( mlx_array* res, const mlx_array logits, @@ -6424,6 +6616,7 @@ static inline int mlx_random_categorical_shape( const mlx_stream s) { return mlx_random_categorical_shape_(res, logits, axis, shape, shape_num, key, s); } + static inline int mlx_random_categorical_num_samples( mlx_array* res, const mlx_array logits_, @@ -6433,6 +6626,7 @@ static inline int mlx_random_categorical_num_samples( const mlx_stream s) { return mlx_random_categorical_num_samples_(res, logits_, axis, num_samples, key, s); } + static inline int mlx_random_categorical( mlx_array* res, const mlx_array logits, @@ -6441,6 +6635,7 @@ static inline int mlx_random_categorical( const mlx_stream s) { return mlx_random_categorical_(res, logits, axis, key, s); } + static inline int mlx_random_gumbel( mlx_array* res, const int* shape, @@ -6450,9 +6645,11 @@ static inline int mlx_random_gumbel( const mlx_stream s) { return mlx_random_gumbel_(res, shape, shape_num, dtype, key, s); } + static inline int mlx_random_key(mlx_array* res, uint64_t seed) { return mlx_random_key_(res, seed); } + static inline int mlx_random_laplace( mlx_array* res, const int* shape, @@ -6464,6 +6661,7 @@ static inline int mlx_random_laplace( const mlx_stream s) { return mlx_random_laplace_(res, shape, shape_num, dtype, loc, scale, key, s); } + static inline int mlx_random_multivariate_normal( mlx_array* res, const mlx_array mean, @@ -6475,6 +6673,7 @@ static inline int mlx_random_multivariate_normal( const mlx_stream s) { return mlx_random_multivariate_normal_(res, mean, cov, shape, shape_num, dtype, key, s); } + static inline int mlx_random_normal_broadcast( mlx_array* res, const int* shape, @@ -6486,6 +6685,7 @@ static inline int mlx_random_normal_broadcast( const mlx_stream s) { return mlx_random_normal_broadcast_(res, shape, shape_num, dtype, loc, scale, key, s); } + static inline int mlx_random_normal( mlx_array* res, const int* shape, @@ -6497,6 +6697,7 @@ static inline int mlx_random_normal( const mlx_stream s) { return mlx_random_normal_(res, shape, shape_num, dtype, loc, scale, key, s); } + static inline int mlx_random_permutation( mlx_array* res, const mlx_array x, @@ -6505,6 +6706,7 @@ static inline int mlx_random_permutation( const mlx_stream s) { return mlx_random_permutation_(res, x, axis, key, s); } + static inline int mlx_random_permutation_arange( mlx_array* res, int x, @@ -6512,6 +6714,7 @@ static inline int mlx_random_permutation_arange( const mlx_stream s) { return mlx_random_permutation_arange_(res, x, key, s); } + static inline int mlx_random_randint( mlx_array* res, const mlx_array low, @@ -6523,9 +6726,11 @@ static inline int mlx_random_randint( const mlx_stream s) { return mlx_random_randint_(res, low, high, shape, shape_num, dtype, key, s); } + static inline int mlx_random_seed(uint64_t seed) { return mlx_random_seed_(seed); } + static inline int mlx_random_split_num( mlx_array* res, const mlx_array key, @@ -6533,6 +6738,7 @@ static inline int mlx_random_split_num( const mlx_stream s) { return mlx_random_split_num_(res, key, num, s); } + static inline int mlx_random_split( mlx_array* res_0, mlx_array* res_1, @@ -6540,6 +6746,7 @@ static inline int mlx_random_split( const mlx_stream s) { return mlx_random_split_(res_0, res_1, key, s); } + static inline int mlx_random_truncated_normal( mlx_array* res, const mlx_array lower, @@ -6551,6 +6758,7 @@ static inline int mlx_random_truncated_normal( const mlx_stream s) { return mlx_random_truncated_normal_(res, lower, upper, shape, shape_num, dtype, key, s); } + static inline int mlx_random_uniform( mlx_array* res, const mlx_array low, @@ -6562,106 +6770,79 @@ static inline int mlx_random_uniform( const mlx_stream s) { return mlx_random_uniform_(res, low, high, shape, shape_num, dtype, key, s); } + static inline mlx_stream mlx_stream_new(void) { return mlx_stream_new_(); } + static inline mlx_stream mlx_stream_new_device(mlx_device dev) { return mlx_stream_new_device_(dev); } + static inline int mlx_stream_set(mlx_stream* stream, const mlx_stream src) { return mlx_stream_set_(stream, src); } + static inline int mlx_stream_free(mlx_stream stream) { return mlx_stream_free_(stream); } + static inline int mlx_stream_tostring(mlx_string* str, mlx_stream stream) { return mlx_stream_tostring_(str, stream); } + static inline bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs) { return mlx_stream_equal_(lhs, rhs); } + static inline int mlx_stream_get_device(mlx_device* dev, mlx_stream stream) { return mlx_stream_get_device_(dev, stream); } + static inline int mlx_stream_get_index(int* index, mlx_stream stream) { return mlx_stream_get_index_(index, stream); } + static inline int mlx_synchronize(mlx_stream stream) { return mlx_synchronize_(stream); } + static inline int mlx_get_default_stream(mlx_stream* stream, mlx_device dev) { return mlx_get_default_stream_(stream, dev); } + static inline int mlx_set_default_stream(mlx_stream stream) { return mlx_set_default_stream_(stream); } + static inline mlx_stream mlx_default_cpu_stream_new(void) { return mlx_default_cpu_stream_new_(); } + static inline mlx_stream mlx_default_gpu_stream_new(void) { return mlx_default_gpu_stream_new_(); } + static inline mlx_string mlx_string_new(void) { return mlx_string_new_(); } + static inline mlx_string mlx_string_new_data(const char* str) { return mlx_string_new_data_(str); } + static inline int mlx_string_set(mlx_string* str, const mlx_string src) { return mlx_string_set_(str, src); } + static inline const char * mlx_string_data(mlx_string str) { return mlx_string_data_(str); } + static inline int mlx_string_free(mlx_string str) { return mlx_string_free_(str); } -static inline int mlx_async_eval(const mlx_vector_array outputs) { - return mlx_async_eval_(outputs); -} -static inline int mlx_checkpoint(mlx_closure* res, const mlx_closure fun) { - return mlx_checkpoint_(res, fun); -} -static inline int mlx_custom_function( - mlx_closure* res, - const mlx_closure fun, - const mlx_closure_custom fun_vjp /* may be null */, - const mlx_closure_custom_jvp fun_jvp /* may be null */, - const mlx_closure_custom_vmap fun_vmap /* may be null */) { - return mlx_custom_function_(res, fun, fun_vjp, fun_jvp, fun_vmap); -} -static inline int mlx_custom_vjp( - mlx_closure* res, - const mlx_closure fun, - const mlx_closure_custom fun_vjp) { - return mlx_custom_vjp_(res, fun, fun_vjp); -} -static inline int mlx_eval(const mlx_vector_array outputs) { - return mlx_eval_(outputs); -} -static inline int mlx_jvp( - mlx_vector_array* res_0, - mlx_vector_array* res_1, - const mlx_closure fun, - const mlx_vector_array primals, - const mlx_vector_array tangents) { - return mlx_jvp_(res_0, res_1, fun, primals, tangents); -} -static inline int mlx_value_and_grad( - mlx_closure_value_and_grad* res, - const mlx_closure fun, - const int* argnums, - size_t argnums_num) { - return mlx_value_and_grad_(res, fun, argnums, argnums_num); -} -static inline int mlx_vjp( - mlx_vector_array* res_0, - mlx_vector_array* res_1, - const mlx_closure fun, - const mlx_vector_array primals, - const mlx_vector_array cotangents) { - return mlx_vjp_(res_0, res_1, fun, primals, cotangents); -} + static inline int mlx_detail_vmap_replace( mlx_vector_array* res, const mlx_vector_array inputs, @@ -6673,6 +6854,7 @@ static inline int mlx_detail_vmap_replace( size_t out_axes_num) { return mlx_detail_vmap_replace_(res, inputs, s_inputs, s_outputs, in_axes, in_axes_num, out_axes, out_axes_num); } + static inline int mlx_detail_vmap_trace( mlx_vector_array* res_0, mlx_vector_array* res_1, @@ -6682,173 +6864,272 @@ static inline int mlx_detail_vmap_trace( size_t in_axes_num) { return mlx_detail_vmap_trace_(res_0, res_1, fun, inputs, in_axes, in_axes_num); } + +static inline int mlx_async_eval(const mlx_vector_array outputs) { + return mlx_async_eval_(outputs); +} + +static inline int mlx_checkpoint(mlx_closure* res, const mlx_closure fun) { + return mlx_checkpoint_(res, fun); +} + +static inline int mlx_custom_function( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp /* may be null */, + const mlx_closure_custom_jvp fun_jvp /* may be null */, + const mlx_closure_custom_vmap fun_vmap /* may be null */) { + return mlx_custom_function_(res, fun, fun_vjp, fun_jvp, fun_vmap); +} + +static inline int mlx_custom_vjp( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp) { + return mlx_custom_vjp_(res, fun, fun_vjp); +} + +static inline int mlx_eval(const mlx_vector_array outputs) { + return mlx_eval_(outputs); +} + +static inline int mlx_jvp( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array tangents) { + return mlx_jvp_(res_0, res_1, fun, primals, tangents); +} + +static inline int mlx_value_and_grad( + mlx_closure_value_and_grad* res, + const mlx_closure fun, + const int* argnums, + size_t argnums_num) { + return mlx_value_and_grad_(res, fun, argnums, argnums_num); +} + +static inline int mlx_vjp( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array cotangents) { + return mlx_vjp_(res_0, res_1, fun, primals, cotangents); +} + static inline mlx_vector_array mlx_vector_array_new(void) { return mlx_vector_array_new_(); } + static inline int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src) { return mlx_vector_array_set_(vec, src); } + static inline int mlx_vector_array_free(mlx_vector_array vec) { return mlx_vector_array_free_(vec); } + static inline mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size) { return mlx_vector_array_new_data_(data, size); } + static inline mlx_vector_array mlx_vector_array_new_value(const mlx_array val) { return mlx_vector_array_new_value_(val); } + static inline int mlx_vector_array_set_data( mlx_vector_array* vec, const mlx_array* data, size_t size) { return mlx_vector_array_set_data_(vec, data, size); } + static inline int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val) { return mlx_vector_array_set_value_(vec, val); } + static inline int mlx_vector_array_append_data( mlx_vector_array vec, const mlx_array* data, size_t size) { return mlx_vector_array_append_data_(vec, data, size); } + static inline int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val) { return mlx_vector_array_append_value_(vec, val); } + static inline size_t mlx_vector_array_size(mlx_vector_array vec) { return mlx_vector_array_size_(vec); } + static inline int mlx_vector_array_get( mlx_array* res, const mlx_vector_array vec, size_t idx) { return mlx_vector_array_get_(res, vec, idx); } + static inline mlx_vector_vector_array mlx_vector_vector_array_new(void) { return mlx_vector_vector_array_new_(); } + static inline int mlx_vector_vector_array_set( mlx_vector_vector_array* vec, const mlx_vector_vector_array src) { return mlx_vector_vector_array_set_(vec, src); } + static inline int mlx_vector_vector_array_free(mlx_vector_vector_array vec) { return mlx_vector_vector_array_free_(vec); } + static inline mlx_vector_vector_array mlx_vector_vector_array_new_data( const mlx_vector_array* data, size_t size) { return mlx_vector_vector_array_new_data_(data, size); } + static inline mlx_vector_vector_array mlx_vector_vector_array_new_value( const mlx_vector_array val) { return mlx_vector_vector_array_new_value_(val); } + static inline int mlx_vector_vector_array_set_data( mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size) { return mlx_vector_vector_array_set_data_(vec, data, size); } + static inline int mlx_vector_vector_array_set_value( mlx_vector_vector_array* vec, const mlx_vector_array val) { return mlx_vector_vector_array_set_value_(vec, val); } + static inline int mlx_vector_vector_array_append_data( mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size) { return mlx_vector_vector_array_append_data_(vec, data, size); } + static inline int mlx_vector_vector_array_append_value( mlx_vector_vector_array vec, const mlx_vector_array val) { return mlx_vector_vector_array_append_value_(vec, val); } + static inline size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec) { return mlx_vector_vector_array_size_(vec); } + static inline int mlx_vector_vector_array_get( mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx) { return mlx_vector_vector_array_get_(res, vec, idx); } + static inline mlx_vector_int mlx_vector_int_new(void) { return mlx_vector_int_new_(); } + static inline int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src) { return mlx_vector_int_set_(vec, src); } + static inline int mlx_vector_int_free(mlx_vector_int vec) { return mlx_vector_int_free_(vec); } + static inline mlx_vector_int mlx_vector_int_new_data(int* data, size_t size) { return mlx_vector_int_new_data_(data, size); } + static inline mlx_vector_int mlx_vector_int_new_value(int val) { return mlx_vector_int_new_value_(val); } + static inline int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size) { return mlx_vector_int_set_data_(vec, data, size); } + static inline int mlx_vector_int_set_value(mlx_vector_int* vec, int val) { return mlx_vector_int_set_value_(vec, val); } + static inline int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size) { return mlx_vector_int_append_data_(vec, data, size); } + static inline int mlx_vector_int_append_value(mlx_vector_int vec, int val) { return mlx_vector_int_append_value_(vec, val); } + static inline size_t mlx_vector_int_size(mlx_vector_int vec) { return mlx_vector_int_size_(vec); } + static inline int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx) { return mlx_vector_int_get_(res, vec, idx); } + static inline mlx_vector_string mlx_vector_string_new(void) { return mlx_vector_string_new_(); } + static inline int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src) { return mlx_vector_string_set_(vec, src); } + static inline int mlx_vector_string_free(mlx_vector_string vec) { return mlx_vector_string_free_(vec); } + static inline mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size) { return mlx_vector_string_new_data_(data, size); } + static inline mlx_vector_string mlx_vector_string_new_value(const char* val) { return mlx_vector_string_new_value_(val); } + static inline int mlx_vector_string_set_data( mlx_vector_string* vec, const char** data, size_t size) { return mlx_vector_string_set_data_(vec, data, size); } + static inline int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val) { return mlx_vector_string_set_value_(vec, val); } + static inline int mlx_vector_string_append_data( mlx_vector_string vec, const char** data, size_t size) { return mlx_vector_string_append_data_(vec, data, size); } + static inline int mlx_vector_string_append_value(mlx_vector_string vec, const char* val) { return mlx_vector_string_append_value_(vec, val); } + static inline size_t mlx_vector_string_size(mlx_vector_string vec) { return mlx_vector_string_size_(vec); } + static inline int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx) { return mlx_vector_string_get_(res, vec, idx); } + static inline int mlx_version(mlx_string* str_) { return mlx_version_(str_); } -#endif // MLX_GENERATED_H +#endif // MLX_GENERATED_H \ No newline at end of file diff --git a/x/mlxrunner/mlx/generator/generated.h.gotmpl b/x/mlxrunner/mlx/generator/generated.h.gotmpl index 594a3f3e3..8f043573b 100644 --- a/x/mlxrunner/mlx/generator/generated.h.gotmpl +++ b/x/mlxrunner/mlx/generator/generated.h.gotmpl @@ -4,10 +4,6 @@ #define MLX_GENERATED_H #include "dynamic.h" -{{ range .Functions }} -#define {{ .Name }} {{ .Name }}_mlx_gen_orig_ -{{- end }} - #include "mlx/c/mlx.h" {{ range .Functions }} #undef {{ .Name }}