mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-06-27 17:42:04 +08:00
cython wrapper for acados (#22784)
* cython wrapper for acados * fix building * sconscript cleanup * no cython numpy * cleanup * upgrade build script * try without slices * new acados commit * c3 update acados libs * c2 libs * make faster * undo profiling * fix build * somewhat faster * tryout cost_set_slice * Revert "tryout cost_set_slice" This reverts commit d358d93a133270e4edab9e7c07ffb6f577c52bd6. * cleanup * undo t_renderer change Co-authored-by: Comma Device <device@comma.ai> old-commit-hash: 89d0a52d16872c403c69426ab32e5788a41ee2ec
This commit is contained in:
@@ -68,6 +68,7 @@ lenv = {
|
||||
"PYTHONPATH": Dir("#").abspath + ":" + Dir("#pyextra/").abspath,
|
||||
|
||||
"ACADOS_SOURCE_DIR": Dir("#third_party/acados/acados").abspath,
|
||||
"ACADOS_PYTHON_INTERFACE_PATH": Dir("#pyextra/acados_template").abspath,
|
||||
"TERA_PATH": Dir("#").abspath + f"/third_party/acados/{arch}/t_renderer",
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:32dae3052f331ee34d628ef535709b301259a45df7c7522c4d35dcf49873f00b
|
||||
size 13
|
||||
oid sha256:477f73573a50b1ae2740849e1aed4f8d353ead59116f95b897e8f620e8dbf31b
|
||||
size 62
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:23cbba2db6b7c99e303802e65d4714a24f7b943d32aae0feb80f596b821ea13b
|
||||
size 13664
|
||||
oid sha256:81426176ea04cb7ed7b4b17c2941f49dcc8c2d576c2fc78ff158a6e9ec2839e5
|
||||
size 13723
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9b8238f6510a691d9b686eeba5612ac3c20fcc434f19f88f6c1039b821621fd2
|
||||
size 7254
|
||||
oid sha256:a27c58335c2ffae67468f2cbac2ecfd59300d8af5a97b15cc29273dff20a0aef
|
||||
size 8020
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d02c555bdde7266d683d3c2d6fa9812a86b7e3024f841211158197fca94d1f08
|
||||
size 95132
|
||||
oid sha256:7bbb8ea1f67d4b2247be1eb7e4e6c2d5cc89fac823d0ca4337804a8c0f8de387
|
||||
size 95897
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b406ef368856f2e70e02a56b42361d7c5937383529654489318e5831a1d23926
|
||||
size 60258
|
||||
oid sha256:4463ca9b30ad759498497085bfb94e4d11994dfb88e915aa209c124da13254c1
|
||||
size 70093
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4ec5899c033238f4815409cabdb109b5e7b833db9018a54345e927885bf12d1a
|
||||
size 17627
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1ab310b29ad2be1b0216b01db7d340e501d29add43cf786a2ef0917f95d6dff7
|
||||
size 10392
|
||||
oid sha256:534e0fcc2b75f6217fb748e9c8a1c365ad2a3b201bcba895484995fb5ed78133
|
||||
size 11803
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f56b25cddcdd2043d91a089e3f0ddf670f46a24cd5badd7ac03cc59ffe85caa2
|
||||
size 700
|
||||
oid sha256:7e21d44ed88590aaa52541839e2e8a8d008ceb5d82c7e2af9736df7b0f4eb5e9
|
||||
size 759
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e784ba1029a75e244a5671a57b8b28ffd569099a61ad707ac86cea03fa637396
|
||||
size 16349
|
||||
oid sha256:8d118afbf4993842e202870ff43675d6837befa3a552ddace49b334e6715bb72
|
||||
size 16342
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:dd53981de52f81cccaf7b35054bde06dba358c3425f792877fa837e9e74ecf1a
|
||||
size 4216
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:22950565d08e2f1c34dd622800279d21af70266de45ecb5f07909d87caf636d1
|
||||
size 18458
|
||||
oid sha256:ae30257d087cf8a0e41a9c9d6c41bb9bc1700059678bc7e82cf899f32f26c71a
|
||||
size 19209
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:aba81942a0b7cf9a2b14f8b18fbde557dc974358de5989ff4fd3c8725d429056
|
||||
size 17368
|
||||
oid sha256:b69ab4513e92ed696e708028a1862fae9f6f38f6c0c3de284f198e5e189fb9b0
|
||||
size 17381
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:01c5708e3cad58a5448ea429000b0110c1957d9d254f822d9e0ecc058533eda5
|
||||
size 2529
|
||||
oid sha256:cfab0ed0fac52bea3bb4dae81dd45f18764330f56d08dfed5c80d11923ad6b44
|
||||
size 2555
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6bdee4b81723dd296c353d1ea4b2cf39d5247912f9d06fc98c075831cb060b06
|
||||
size 19886
|
||||
oid sha256:7cfae123bb078e0637621b339a14d0b776f9c451ba50939ed9ec341a8daa030e
|
||||
size 19912
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:65a7ed464643a436c1baf426458674fdd49ef77952b1a63da221b1b08c72ac4c
|
||||
size 2180
|
||||
oid sha256:1b5a8e5e2f2bf2456962c53842de00d236d67000b98cb61e48c968fe760eb6a5
|
||||
size 2206
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:504de8ac91b2dc75dcb01aa6894b8435e57db03cf387c1e1f48d7d03b2f737d6
|
||||
size 26492
|
||||
oid sha256:d1ed4a97c47e85468b4a19bb0b0dae365f5b13bff7054a58e547a2e530570ee0
|
||||
size 25463
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:690484ab973bf37151fd437c65e3f708fba8c6a4fe6680054acc5ff3e4aa36c5
|
||||
size 3976
|
||||
oid sha256:6d12f0cbf3624a4dbf05cd50fadcbf895507d84f869c741bc1c2876ea7bdf18f
|
||||
size 4193
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ddeef3b42e38682017002ed2e176dc03edba8292464ae1ba44c7cfdd59897f6a
|
||||
size 112677
|
||||
oid sha256:46068c324d8a79f92eb0b758067e68708829fcd31b6f6106fdd57ddaa3ba4e00
|
||||
size 115538
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bce8b99e1661834d1e20777634844ecc31c6953a1bd58b52ad9dbe67ef23a656
|
||||
size 6135
|
||||
oid sha256:4d0202112aabfba5b642e7bab59c4608d3227981ba22509166f9ce3c55a3a39b
|
||||
size 9834
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2d92fa9835ac4149aa8ad7a4a29a8b4c1a6234c22d43c31df313f94e2602f9e9
|
||||
size 1708
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7ad7b08956c4980fa57f4386ddf4ded811049adc2690933984cc0c80b842b9aa
|
||||
size 28517
|
||||
oid sha256:6395761be4c3e202197dd6d1953d52299260f2ff6a9fd13bf122a752fdc2467b
|
||||
size 28556
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1539ee98928959ec0a4b50ee33a25f5a18028a903be0e5bd09e4247e2cadc7f4
|
||||
size 6639
|
||||
oid sha256:c30036c8497a66d000b05c7dcf8e7f37a3fed477885cf5d9d4600eacbc6cf96d
|
||||
size 7983
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ef55fa190b17ab806cb233c53ab041f0f6b4e99d535bcffd0e416946cd1bcad2
|
||||
size 4557
|
||||
oid sha256:659efc1476d5b27bb6ade88223762fe65d2aa1e96bae0c1fdd7dec6e167fe02a
|
||||
size 4677
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e97ba33997f96d57bc70970548f79fc83fee147668cd75e4cefd097d37eda117
|
||||
size 6716
|
||||
oid sha256:e3ca280cc474c3fa50496ca7e21d3b7c918843c6c4f3737627538eb3a8847148
|
||||
size 6728
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ba36522a9299fe913e9dcc4890adc6fd74e9c33b6234c4b5468ebedccc1a8096
|
||||
size 3665
|
||||
oid sha256:25cbd49d4188265253ff16f537bb13b64b2d20957416009b8bf09704952dd318
|
||||
size 3677
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7523392004ef1f5f710ded1ccd415d9e5b6d0601e125beedd7968e333141a3a2
|
||||
size 4406
|
||||
oid sha256:148abc2a03f578abaea318edd1042c01d3bb04fc6d9e33b7cab0bab48d9ca53f
|
||||
size 4418
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a1c794693e86d1b5f70674c4fd78575409ee34a33520ca9ab295cbfe9d0818c3
|
||||
size 4091
|
||||
oid sha256:df61d4a3b7c0e04bb02dab5408be15db88e293b481301febb129aedd12cd0cab
|
||||
size 4337
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:68abbd3ba7391d5bc617eef0a6a47c5f9b579de77307cc3b5ec40de01d85a08d
|
||||
size 4923
|
||||
oid sha256:8ad9558c20c7bf4a526874c51204a86cc41b884d071ff55fab9c778a6f2ef999
|
||||
size 4935
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8c7bc6cc1f33a6c5c04fe5012284486126f59e3b28e8f518ea3691fc1f68ff60
|
||||
size 5189
|
||||
oid sha256:14938e8888f7fc72cbf0d1c8614d39952be8909aeb707975afb9b03e2aa817eb
|
||||
size 5201
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6a2fd8843020aee038722f9028ccbfb8b7147cec8f23f4a4249f6d262b741e09
|
||||
size 3852
|
||||
oid sha256:b76eeb34b5ff74e072d53cbd08e464ebab8361e805aa6063b23de821ea400287
|
||||
size 3864
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:de3a26bc1112e1fff37bebb2b0736fb88bc08b9aab182307acaa209161fd174d
|
||||
size 15426
|
||||
oid sha256:3172aa8e2130b7d411bc330e0b7b3038ee0896db22bcdee9ef045392f03ea6ce
|
||||
size 15924
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
Import('env', 'arch')
|
||||
Import('env', 'envCython', 'arch', 'common')
|
||||
|
||||
gen = "c_generated_code"
|
||||
|
||||
@@ -33,6 +33,7 @@ generated_files = [
|
||||
|
||||
f'{gen}/main_lat.c',
|
||||
f'{gen}/acados_solver_lat.h',
|
||||
f'{gen}/acados_solver.pxd',
|
||||
|
||||
f'{gen}/lat_model/lat_expl_vde_adj.c',
|
||||
|
||||
@@ -53,6 +54,24 @@ lenv["CFLAGS"].append("-DACADOS_WITH_QPOASES")
|
||||
lenv["CXXFLAGS"].append("-DACADOS_WITH_QPOASES")
|
||||
lenv["CCFLAGS"].append("-Wno-unused")
|
||||
lenv["LINKFLAGS"].append("-Wl,--disable-new-dtags")
|
||||
lenv.SharedLibrary(f"{gen}/acados_ocp_solver_lat",
|
||||
build_files,
|
||||
LIBS=['m', 'acados', 'hpipm', 'blasfeo', 'qpOASES_e'])
|
||||
lib_solver = lenv.SharedLibrary(f"{gen}/acados_ocp_solver_lat",
|
||||
build_files,
|
||||
LIBS=['m', 'acados', 'hpipm', 'blasfeo', 'qpOASES_e'])
|
||||
|
||||
# generate cython stuff
|
||||
acados_ocp_solver_pyx = File("#pyextra/acados_template/acados_ocp_solver_pyx.pyx")
|
||||
acados_ocp_solver_common = File("#pyextra/acados_template/acados_solver_common.pxd")
|
||||
libacados_ocp_solver_pxd = File(f'{gen}/acados_solver.pxd')
|
||||
libacados_ocp_solver_c = File(f'{gen}/acados_ocp_solver_pyx.c')
|
||||
|
||||
lenv2 = envCython.Clone()
|
||||
lenv2["LINKFLAGS"] += [lib_solver[0].get_labspath()]
|
||||
lenv2.Command(libacados_ocp_solver_c,
|
||||
[acados_ocp_solver_pyx, acados_ocp_solver_common, libacados_ocp_solver_pxd],
|
||||
f'cython' + \
|
||||
f' -o {libacados_ocp_solver_c.get_labspath()}' + \
|
||||
f' -I {libacados_ocp_solver_pxd.get_dir().get_labspath()}' + \
|
||||
f' -I {acados_ocp_solver_common.get_dir().get_labspath()}' + \
|
||||
f' {acados_ocp_solver_pyx.get_labspath()}')
|
||||
lib_cython = lenv2.Program(f'{gen}/acados_ocp_solver_pyx.so', [libacados_ocp_solver_c])
|
||||
lenv2.Depends(lib_cython, lib_solver)
|
||||
|
||||
@@ -5,8 +5,12 @@ import numpy as np
|
||||
from casadi import SX, vertcat, sin, cos
|
||||
from selfdrive.controls.lib.drive_helpers import LAT_MPC_N as N
|
||||
from selfdrive.controls.lib.drive_helpers import T_IDXS
|
||||
from pyextra.acados_template import AcadosModel, AcadosOcp, AcadosOcpSolver
|
||||
|
||||
if __name__ == '__main__': # generating code
|
||||
from pyextra.acados_template import AcadosModel, AcadosOcp, AcadosOcpSolver
|
||||
else:
|
||||
# from pyextra.acados_template import AcadosOcpSolverFast
|
||||
from selfdrive.controls.lib.lateral_mpc_lib.c_generated_code.acados_ocp_solver_pyx import AcadosOcpSolverFast # pylint: disable=no-name-in-module, import-error
|
||||
|
||||
LAT_MPC_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
EXPORT_DIR = os.path.join(LAT_MPC_DIR, "c_generated_code")
|
||||
@@ -110,17 +114,16 @@ def gen_lat_mpc_solver():
|
||||
|
||||
class LateralMpc():
|
||||
def __init__(self, x0=np.zeros(X_DIM)):
|
||||
self.solver = AcadosOcpSolver('lat', N, EXPORT_DIR)
|
||||
self.solver = AcadosOcpSolverFast('lat', N, EXPORT_DIR)
|
||||
self.reset(x0)
|
||||
|
||||
def reset(self, x0=np.zeros(X_DIM)):
|
||||
self.x_sol = np.zeros((N+1, X_DIM))
|
||||
self.u_sol = np.zeros((N, 1))
|
||||
self.yref = np.zeros((N+1, 3))
|
||||
self.solver.cost_set_slice(0, N, "yref", self.yref[:N])
|
||||
for i in range(N):
|
||||
self.solver.cost_set(i, "yref", self.yref[i])
|
||||
self.solver.cost_set(N, "yref", self.yref[N][:2])
|
||||
W = np.eye(3)
|
||||
self.Ws = np.tile(W[None], reps=(N,1,1))
|
||||
|
||||
# Somehow needed for stable init
|
||||
for i in range(N+1):
|
||||
@@ -132,12 +135,11 @@ class LateralMpc():
|
||||
self.cost = 0
|
||||
|
||||
def set_weights(self, path_weight, heading_weight, steer_rate_weight):
|
||||
self.Ws[:,0,0] = path_weight
|
||||
self.Ws[:,1,1] = heading_weight
|
||||
self.Ws[:,2,2] = steer_rate_weight
|
||||
self.solver.cost_set_slice(0, N, 'W', self.Ws, api='old')
|
||||
W = np.asfortranarray(np.diag([path_weight, heading_weight, steer_rate_weight]))
|
||||
for i in range(N):
|
||||
self.solver.cost_set(i, 'W', W)
|
||||
#TODO hacky weights to keep behavior the same
|
||||
self.solver.cost_set(N, 'W', (3/20.)*self.Ws[0,:2,:2])
|
||||
self.solver.cost_set(N, 'W', (3/20.)*W[:2,:2])
|
||||
|
||||
def run(self, x0, v_ego, car_rotation_radius, y_pts, heading_pts):
|
||||
x0_cp = np.copy(x0)
|
||||
@@ -145,12 +147,15 @@ class LateralMpc():
|
||||
self.solver.constraints_set(0, "ubx", x0_cp)
|
||||
self.yref[:,0] = y_pts
|
||||
self.yref[:,1] = heading_pts*(v_ego+5.0)
|
||||
self.solver.cost_set_slice(0, N, "yref", self.yref[:N])
|
||||
for i in range(N):
|
||||
self.solver.cost_set(i, "yref", self.yref[i])
|
||||
self.solver.cost_set(N, "yref", self.yref[N][:2])
|
||||
|
||||
self.solution_status = self.solver.solve()
|
||||
self.solver.fill_in_slice(0, N+1, 'x', self.x_sol)
|
||||
self.solver.fill_in_slice(0, N, 'u', self.u_sol)
|
||||
for i in range(N+1):
|
||||
self.x_sol[i] = self.solver.get(i, 'x')
|
||||
for i in range(N):
|
||||
self.u_sol[i] = self.solver.get(i, 'u')
|
||||
self.cost = self.solver.get_cost()
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
Import('env', 'arch')
|
||||
Import('env', 'envCython', 'arch', 'common')
|
||||
|
||||
gen = "c_generated_code"
|
||||
|
||||
@@ -41,6 +41,7 @@ generated_files = [
|
||||
|
||||
f'{gen}/main_long.c',
|
||||
f'{gen}/acados_solver_long.h',
|
||||
f'{gen}/acados_solver.pxd',
|
||||
|
||||
f'{gen}/long_model/long_expl_vde_adj.c',
|
||||
|
||||
@@ -63,6 +64,24 @@ lenv["CFLAGS"].append("-DACADOS_WITH_QPOASES")
|
||||
lenv["CXXFLAGS"].append("-DACADOS_WITH_QPOASES")
|
||||
lenv["CCFLAGS"].append("-Wno-unused")
|
||||
lenv["LINKFLAGS"].append("-Wl,--disable-new-dtags")
|
||||
lenv.SharedLibrary(f"{gen}/acados_ocp_solver_long",
|
||||
build_files,
|
||||
LIBS=['m', 'acados', 'hpipm', 'blasfeo', 'qpOASES_e'])
|
||||
lib_solver = lenv.SharedLibrary(f"{gen}/acados_ocp_solver_long",
|
||||
build_files,
|
||||
LIBS=['m', 'acados', 'hpipm', 'blasfeo', 'qpOASES_e'])
|
||||
|
||||
# generate cython stuff
|
||||
acados_ocp_solver_pyx = File("#pyextra/acados_template/acados_ocp_solver_pyx.pyx")
|
||||
acados_ocp_solver_common = File("#pyextra/acados_template/acados_solver_common.pxd")
|
||||
libacados_ocp_solver_pxd = File(f'{gen}/acados_solver.pxd')
|
||||
libacados_ocp_solver_c = File(f'{gen}/acados_ocp_solver_pyx.c')
|
||||
|
||||
lenv2 = envCython.Clone()
|
||||
lenv2["LINKFLAGS"] += [lib_solver[0].get_labspath()]
|
||||
lenv2.Command(libacados_ocp_solver_c,
|
||||
[acados_ocp_solver_pyx, acados_ocp_solver_common, libacados_ocp_solver_pxd],
|
||||
f'cython' + \
|
||||
f' -o {libacados_ocp_solver_c.get_labspath()}' + \
|
||||
f' -I {libacados_ocp_solver_pxd.get_dir().get_labspath()}' + \
|
||||
f' -I {acados_ocp_solver_common.get_dir().get_labspath()}' + \
|
||||
f' {acados_ocp_solver_pyx.get_labspath()}')
|
||||
lib_cython = lenv2.Program(f'{gen}/acados_ocp_solver_pyx.so', [libacados_ocp_solver_c])
|
||||
lenv2.Depends(lib_cython, lib_solver)
|
||||
|
||||
@@ -8,7 +8,12 @@ from selfdrive.swaglog import cloudlog
|
||||
from selfdrive.modeld.constants import index_function
|
||||
from selfdrive.controls.lib.radar_helpers import _LEAD_ACCEL_TAU
|
||||
|
||||
from pyextra.acados_template import AcadosModel, AcadosOcp, AcadosOcpSolver
|
||||
if __name__ == '__main__': # generating code
|
||||
from pyextra.acados_template import AcadosModel, AcadosOcp, AcadosOcpSolver
|
||||
else:
|
||||
# from pyextra.acados_template import AcadosOcpSolver as AcadosOcpSolverFast
|
||||
from selfdrive.controls.lib.longitudinal_mpc_lib.c_generated_code.acados_ocp_solver_pyx import AcadosOcpSolverFast # pylint: disable=no-name-in-module, import-error
|
||||
|
||||
from casadi import SX, vertcat
|
||||
|
||||
LONG_MPC_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -190,13 +195,14 @@ class LongitudinalMpc():
|
||||
self.source = SOURCES[2]
|
||||
|
||||
def reset(self):
|
||||
self.solver = AcadosOcpSolver('long', N, EXPORT_DIR)
|
||||
self.solver = AcadosOcpSolverFast('long', N, EXPORT_DIR)
|
||||
self.v_solution = [0.0 for i in range(N+1)]
|
||||
self.a_solution = [0.0 for i in range(N+1)]
|
||||
self.j_solution = [0.0 for i in range(N)]
|
||||
self.yref = np.zeros((N+1, COST_DIM))
|
||||
self.solver.cost_set_slice(0, N, "yref", self.yref[:N])
|
||||
self.solver.set(N, "yref", self.yref[N][:COST_E_DIM])
|
||||
for i in range(N):
|
||||
self.solver.cost_set(i, "yref", self.yref[i])
|
||||
self.solver.cost_set(N, "yref", self.yref[N][:COST_E_DIM])
|
||||
self.x_sol = np.zeros((N+1, X_DIM))
|
||||
self.u_sol = np.zeros((N,1))
|
||||
self.params = np.zeros((N+1,3))
|
||||
@@ -216,30 +222,30 @@ class LongitudinalMpc():
|
||||
self.set_weights_for_lead_policy()
|
||||
|
||||
def set_weights_for_lead_policy(self):
|
||||
W = np.diag([X_EGO_OBSTACLE_COST, X_EGO_COST, V_EGO_COST, A_EGO_COST, J_EGO_COST])
|
||||
Ws = np.tile(W[None], reps=(N,1,1))
|
||||
self.solver.cost_set_slice(0, N, 'W', Ws, api='old')
|
||||
W = np.asfortranarray(np.diag([X_EGO_OBSTACLE_COST, X_EGO_COST, V_EGO_COST, A_EGO_COST, J_EGO_COST]))
|
||||
for i in range(N):
|
||||
self.solver.cost_set(i, 'W', W)
|
||||
# Setting the slice without the copy make the array not contiguous,
|
||||
# causing issues with the C interface.
|
||||
self.solver.cost_set(N, 'W', np.copy(W[:COST_E_DIM, :COST_E_DIM]))
|
||||
|
||||
# Set L2 slack cost on lower bound constraints
|
||||
Zl = np.array([LIMIT_COST, LIMIT_COST, LIMIT_COST, DANGER_ZONE_COST])
|
||||
Zls = np.tile(Zl[None], reps=(N+1,1,1))
|
||||
self.solver.cost_set_slice(0, N+1, 'Zl', Zls, api='old')
|
||||
for i in range(N):
|
||||
self.solver.cost_set(i, 'Zl', Zl)
|
||||
|
||||
def set_weights_for_xva_policy(self):
|
||||
W = np.diag([0., 10., 1., 10., 1.])
|
||||
Ws = np.tile(W[None], reps=(N,1,1))
|
||||
self.solver.cost_set_slice(0, N, 'W', Ws, api='old')
|
||||
W = np.asfortranarray(np.diag([0., 10., 1., 10., 1.]))
|
||||
for i in range(N):
|
||||
self.solver.cost_set(i, 'W', W)
|
||||
# Setting the slice without the copy make the array not contiguous,
|
||||
# causing issues with the C interface.
|
||||
self.solver.cost_set(N, 'W', np.copy(W[:COST_E_DIM, :COST_E_DIM]))
|
||||
|
||||
# Set L2 slack cost on lower bound constraints
|
||||
Zl = np.array([LIMIT_COST, LIMIT_COST, LIMIT_COST, 0.0])
|
||||
Zls = np.tile(Zl[None], reps=(N+1,1,1))
|
||||
self.solver.cost_set_slice(0, N+1, 'Zl', Zls, api='old')
|
||||
for i in range(N):
|
||||
self.solver.cost_set(i, 'Zl', Zl)
|
||||
|
||||
def set_cur_state(self, v, a):
|
||||
if abs(self.x0[1] - v) > 1.:
|
||||
@@ -326,8 +332,9 @@ class LongitudinalMpc():
|
||||
self.yref[:,1] = x
|
||||
self.yref[:,2] = v
|
||||
self.yref[:,3] = a
|
||||
self.solver.cost_set_slice(0, N, "yref", self.yref[:N], api='old')
|
||||
self.solver.set(N, "yref", self.yref[N][:COST_E_DIM])
|
||||
for i in range(N):
|
||||
self.solver.cost_set(i, "yref", self.yref[i])
|
||||
self.solver.cost_set(N, "yref", self.yref[N][:COST_E_DIM])
|
||||
self.accel_limit_arr[:,0] = -10.
|
||||
self.accel_limit_arr[:,1] = 10.
|
||||
x_obstacle = 1e5*np.ones((N+1))
|
||||
@@ -338,12 +345,14 @@ class LongitudinalMpc():
|
||||
|
||||
def run(self):
|
||||
for i in range(N+1):
|
||||
self.solver.set_param(i, self.params[i])
|
||||
self.solver.set(i, 'p', self.params[i])
|
||||
self.solver.constraints_set(0, "lbx", self.x0)
|
||||
self.solver.constraints_set(0, "ubx", self.x0)
|
||||
self.solution_status = self.solver.solve()
|
||||
self.solver.fill_in_slice(0, N+1, 'x', self.x_sol)
|
||||
self.solver.fill_in_slice(0, N, 'u', self.u_sol)
|
||||
for i in range(N+1):
|
||||
self.x_sol[i] = self.solver.get(i, 'x')
|
||||
for i in range(N):
|
||||
self.u_sol[i] = self.solver.get(i, 'u')
|
||||
|
||||
self.v_solution = self.x_sol[:,1]
|
||||
self.a_solution = self.x_sol[:,2]
|
||||
|
||||
+2
-2
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c86a0f145884b913abdce6b9b57953cc462e1ef4b0749b7a8e52878b41240281
|
||||
size 567753
|
||||
oid sha256:a42119de38b56c672e473edd8860e2e75c3e2793961b631634f8dc6f6f6618e5
|
||||
size 575769
|
||||
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f459390b07bd4cff5059475e70436a4843469aecf82cb4137bf77a86732648de
|
||||
oid sha256:567bfcb693ce73c8faa847ddd3798037f611c8ebfcecddce4da2140f22b191af
|
||||
size 694193
|
||||
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a363f66e7a0c7b3ce5e4bd202e0320431253f9d3d9f6d6423b031afba4c8017b
|
||||
oid sha256:d6873d1bf018369aaf13cabe5d870d3154656668b31647df0a1a648b3bb383dc
|
||||
size 1324409
|
||||
|
||||
Vendored
+1
-1
@@ -18,7 +18,7 @@ if [ ! -d acados_repo/ ]; then
|
||||
fi
|
||||
cd acados_repo
|
||||
git fetch
|
||||
git checkout 43ba28e95062f9ac9b48facd3b45698d57666fa3
|
||||
git checkout 79e9e3e76f2751198858adf382c97837833ad31f
|
||||
git submodule update --recursive --init
|
||||
|
||||
# build
|
||||
|
||||
@@ -135,6 +135,8 @@ typedef struct ocp_nlp_dims
|
||||
int *nz; // number of algebraic variables
|
||||
int *ns; // number of slack variables
|
||||
int N; // number of shooting nodes
|
||||
|
||||
void *raw_memory; // Pointer to allocated memory, to be used for freeing
|
||||
} ocp_nlp_dims;
|
||||
|
||||
//
|
||||
@@ -203,6 +205,9 @@ typedef struct ocp_nlp_in
|
||||
/// Pointers to constraints functions (TBC).
|
||||
void **constraints;
|
||||
|
||||
/// Pointer to allocated memory, to be used for freeing.
|
||||
void *raw_memory;
|
||||
|
||||
} ocp_nlp_in;
|
||||
|
||||
//
|
||||
@@ -235,6 +240,8 @@ typedef struct ocp_nlp_out
|
||||
double inf_norm_res;
|
||||
double total_time;
|
||||
|
||||
void *raw_memory; // Pointer to allocated memory, to be used for freeing
|
||||
|
||||
} ocp_nlp_out;
|
||||
|
||||
//
|
||||
|
||||
+38
-29
@@ -43,44 +43,53 @@ extern "C" {
|
||||
|
||||
|
||||
|
||||
enum Newton_type_collocation
|
||||
// enum Newton_type_collocation
|
||||
// {
|
||||
// exact = 0,
|
||||
// simplified_in,
|
||||
// simplified_inis
|
||||
// };
|
||||
|
||||
|
||||
|
||||
// typedef struct
|
||||
// {
|
||||
// enum Newton_type_collocation type;
|
||||
// double *eig;
|
||||
// double *low_tria;
|
||||
// bool single;
|
||||
// bool freeze;
|
||||
|
||||
// double *transf1;
|
||||
// double *transf2;
|
||||
|
||||
// double *transf1_T;
|
||||
// double *transf2_T;
|
||||
// } Newton_scheme;
|
||||
|
||||
|
||||
typedef enum
|
||||
{
|
||||
exact = 0,
|
||||
simplified_in,
|
||||
simplified_inis
|
||||
};
|
||||
|
||||
|
||||
|
||||
typedef struct
|
||||
{
|
||||
enum Newton_type_collocation type;
|
||||
double *eig;
|
||||
double *low_tria;
|
||||
bool single;
|
||||
bool freeze;
|
||||
|
||||
double *transf1;
|
||||
double *transf2;
|
||||
|
||||
double *transf1_T;
|
||||
double *transf2_T;
|
||||
} Newton_scheme;
|
||||
|
||||
GAUSS_LEGENDRE,
|
||||
GAUSS_RADAU_IIA,
|
||||
} sim_collocation_type;
|
||||
|
||||
|
||||
//
|
||||
acados_size_t gauss_nodes_work_calculate_size(int ns);
|
||||
// acados_size_t gauss_legendre_nodes_work_calculate_size(int ns);
|
||||
//
|
||||
void gauss_nodes(int ns, double *nodes, void *raw_memory);
|
||||
// void gauss_legendre_nodes(int ns, double *nodes, void *raw_memory);
|
||||
//
|
||||
acados_size_t gauss_simplified_work_calculate_size(int ns);
|
||||
// acados_size_t gauss_simplified_work_calculate_size(int ns);
|
||||
// //
|
||||
// void gauss_simplified(int ns, Newton_scheme *scheme, void *work);
|
||||
//
|
||||
void gauss_simplified(int ns, Newton_scheme *scheme, void *work);
|
||||
acados_size_t butcher_tableau_work_calculate_size(int ns);
|
||||
//
|
||||
acados_size_t butcher_table_work_calculate_size(int ns);
|
||||
// void calculate_butcher_tableau_from_nodes(int ns, double *nodes, double *b, double *A, void *work);
|
||||
//
|
||||
void butcher_table(int ns, double *nodes, double *b, double *A, void *work);
|
||||
void calculate_butcher_tableau(int ns, sim_collocation_type collocation_type, double *c_vec,
|
||||
double *b_vec, double *A_mat, void *work);
|
||||
|
||||
|
||||
|
||||
|
||||
+2
-1
@@ -141,12 +141,13 @@ typedef struct
|
||||
bool output_z; // 1 -- if zn should be computed
|
||||
bool sens_algebraic; // 1 -- if S_algebraic should be computed
|
||||
bool exact_z_output; // 1 -- if z, S_algebraic should be computed exactly, extra Newton iterations
|
||||
sim_collocation_type collocation_type;
|
||||
|
||||
// for explicit integrators: newton_iter == 0 && scheme == NULL
|
||||
// && jac_reuse=false
|
||||
int newton_iter;
|
||||
bool jac_reuse;
|
||||
Newton_scheme *scheme;
|
||||
// Newton_scheme *scheme;
|
||||
|
||||
// workspace
|
||||
void *work;
|
||||
|
||||
+1
-1
@@ -40,7 +40,7 @@ extern "C" {
|
||||
|
||||
#include "acados/utils/types.h"
|
||||
|
||||
#if defined(__DSPACE__)
|
||||
#if defined(__MABX2__)
|
||||
double fmax(double a, double b);
|
||||
#endif
|
||||
|
||||
|
||||
+1
-1
@@ -67,7 +67,7 @@ typedef struct acados_timer_
|
||||
mach_timebase_info_data_t tinfo;
|
||||
} acados_timer;
|
||||
|
||||
#elif defined(__DSPACE__)
|
||||
#elif defined(__MABX2__)
|
||||
|
||||
#include <brtenv.h>
|
||||
|
||||
|
||||
+5
-8
@@ -227,10 +227,8 @@ int ocp_nlp_dynamics_model_set(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_n
|
||||
/// \param field The name of the field, either nls_res_jac,
|
||||
/// y_ref, W (others TBC)
|
||||
/// \param value Cost values.
|
||||
int ocp_nlp_cost_model_set_slice(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_in *in,
|
||||
int start_stage, int end_stage, const char *field, void *value, int dim);
|
||||
int ocp_nlp_cost_model_set(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_in *in,
|
||||
int start_stage, const char *field, void *value);
|
||||
int stage, const char *field, void *value);
|
||||
|
||||
|
||||
/// Sets the function pointers to the constraints functions for the given stage.
|
||||
@@ -241,8 +239,6 @@ int ocp_nlp_cost_model_set(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_i
|
||||
/// \param stage Stage number.
|
||||
/// \param field The name of the field, either lb, ub (others TBC)
|
||||
/// \param value Constraints function or values.
|
||||
int ocp_nlp_constraints_model_set_slice(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_in *in,
|
||||
int start_stage, int end_stage, const char *field, void *value, int dim);
|
||||
int ocp_nlp_constraints_model_set(ocp_nlp_config *config, ocp_nlp_dims *dims,
|
||||
ocp_nlp_in *in, int stage, const char *field, void *value);
|
||||
|
||||
@@ -283,9 +279,6 @@ void ocp_nlp_out_set(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_out *ou
|
||||
void ocp_nlp_out_get(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_out *out,
|
||||
int stage, const char *field, void *value);
|
||||
|
||||
void ocp_nlp_out_get_slice(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_out *out,
|
||||
int start_stage, int end_stage, const char *field, void *value);
|
||||
|
||||
//
|
||||
void ocp_nlp_get_at_stage(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_solver *solver,
|
||||
int stage, const char *field, void *value);
|
||||
@@ -300,6 +293,8 @@ void ocp_nlp_constraint_dims_get_from_attr(ocp_nlp_config *config, ocp_nlp_dims
|
||||
void ocp_nlp_cost_dims_get_from_attr(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_out *out,
|
||||
int stage, const char *field, int *dims_out);
|
||||
|
||||
void ocp_nlp_dynamics_dims_get_from_attr(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_out *out,
|
||||
int stage, const char *field, int *dims_out);
|
||||
|
||||
/* opts */
|
||||
|
||||
@@ -374,6 +369,8 @@ int ocp_nlp_precompute(ocp_nlp_solver *solver, ocp_nlp_in *nlp_in, ocp_nlp_out *
|
||||
/// \param nlp_out The output struct.
|
||||
void ocp_nlp_eval_cost(ocp_nlp_solver *solver, ocp_nlp_in *nlp_in, ocp_nlp_out *nlp_out);
|
||||
|
||||
//
|
||||
void ocp_nlp_eval_residuals(ocp_nlp_solver *solver, ocp_nlp_in *nlp_in, ocp_nlp_out *nlp_out);
|
||||
|
||||
//
|
||||
void ocp_nlp_eval_param_sens(ocp_nlp_solver *solver, char *field, int stage, int index, ocp_nlp_out *sens_nlp_out);
|
||||
|
||||
+5
-5
@@ -45,11 +45,11 @@ extern "C" {
|
||||
|
||||
typedef enum
|
||||
{
|
||||
ERK,
|
||||
IRK,
|
||||
GNSF,
|
||||
LIFTED_IRK,
|
||||
INVALID_SIM_SOLVER,
|
||||
ERK,
|
||||
IRK,
|
||||
GNSF,
|
||||
LIFTED_IRK,
|
||||
INVALID_SIM_SOLVER,
|
||||
} sim_solver_t;
|
||||
|
||||
|
||||
|
||||
@@ -43,7 +43,28 @@
|
||||
|
||||
|
||||
|
||||
#if defined( TARGET_X64_INTEL_HASWELL )
|
||||
#if defined( TARGET_X64_INTEL_SKYLAKE_X )
|
||||
// common
|
||||
#define CACHE_LINE_SIZE 64 // data cache size: 64 bytes
|
||||
#define L1_CACHE_SIZE (32*1024) // L1 data cache size: 32 kB, 8-way
|
||||
#define L2_CACHE_SIZE (256*1024) //(1024*1024) // L2 data cache size: 1 MB ; DTLB1 64*4 kB = 256 kB
|
||||
#define LLC_CACHE_SIZE (6*1024*1024) //(8*1024*1024) // LLC cache size: 8 MB ; TLB 1536*4 kB = 6 MB
|
||||
// double
|
||||
#define D_PS 8 // panel size
|
||||
#define D_PLD 8 // 4 // GCD of panel length
|
||||
#define D_M_KERNEL 24 // max kernel size
|
||||
#define D_KC 128 //256 // 192
|
||||
#define D_NC 144 //72 //96 //72 // 120 // 512
|
||||
#define D_MC 2400 // 6000
|
||||
// single
|
||||
#define S_PS 16 // panel size
|
||||
#define S_PLD 4 // GCD of panel length TODO probably 16 when writing assebly
|
||||
#define S_M_KERNEL 32 // max kernel size
|
||||
#define S_KC 128 //256
|
||||
#define S_NC 128 //144
|
||||
#define S_MC 3000
|
||||
|
||||
#elif defined( TARGET_X64_INTEL_HASWELL )
|
||||
// common
|
||||
#define CACHE_LINE_SIZE 64 // data cache size: 64 bytes
|
||||
#define L1_CACHE_SIZE (32*1024) // L1 data cache size: 32 kB, 8-way
|
||||
|
||||
@@ -101,18 +101,18 @@ void blasfeo_dtrsv_ltu(int m, struct blasfeo_dmat *sA, int ai, int aj, struct bl
|
||||
void blasfeo_dtrsv_unn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= inv( A^T ) * x, A (m)x(m) upper, transposed, not_unit
|
||||
void blasfeo_dtrsv_utn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular
|
||||
void blasfeo_dtrmv_lnn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular, unit diagonal
|
||||
void blasfeo_dtrmv_lnu(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A^T * x ; A lower triangular
|
||||
void blasfeo_dtrmv_ltn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A^T * x ; A lower triangular, unit diagonal
|
||||
void blasfeo_dtrmv_ltu(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= beta * y + alpha * A * x ; A upper triangular
|
||||
void blasfeo_dtrmv_unn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A^T * x ; A upper triangular
|
||||
void blasfeo_dtrmv_utn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular
|
||||
void blasfeo_dtrmv_lnn(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A^T * x ; A lower triangular
|
||||
void blasfeo_dtrmv_ltn(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular, unit diagonal
|
||||
void blasfeo_dtrmv_lnu(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A^T * x ; A lower triangular, unit diagonal
|
||||
void blasfeo_dtrmv_ltu(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z_n <= beta_n * y_n + alpha_n * A * x_n
|
||||
// z_t <= beta_t * y_t + alpha_t * A^T * x_t
|
||||
void blasfeo_dgemv_nt(int m, int n, double alpha_n, double alpha_t, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx_n, int xi_n, struct blasfeo_dvec *sx_t, int xi_t, double beta_n, double beta_t, struct blasfeo_dvec *sy_n, int yi_n, struct blasfeo_dvec *sy_t, int yi_t, struct blasfeo_dvec *sz_n, int zi_n, struct blasfeo_dvec *sz_t, int zi_t);
|
||||
|
||||
@@ -101,23 +101,24 @@ void blasfeo_ref_dtrsv_ltu(int m, struct blasfeo_dmat *sA, int ai, int aj, struc
|
||||
void blasfeo_ref_dtrsv_unn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= inv( A' ) * x, A (m)x(m) upper, transposed, not_unit
|
||||
void blasfeo_ref_dtrsv_utn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular
|
||||
void blasfeo_ref_dtrmv_lnn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular, unit diagonal
|
||||
void blasfeo_ref_dtrmv_lnu(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A' * x ; A lower triangular
|
||||
void blasfeo_ref_dtrmv_ltn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A' * x ; A lower triangular, unit diagonal
|
||||
void blasfeo_ref_dtrmv_ltu(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= beta * y + alpha * A * x ; A upper triangular
|
||||
void blasfeo_ref_dtrmv_unn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A' * x ; A upper triangular
|
||||
void blasfeo_ref_dtrmv_utn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular
|
||||
void blasfeo_ref_dtrmv_lnn(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A' * x ; A lower triangular
|
||||
void blasfeo_ref_dtrmv_ltn(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular, unit diagonal
|
||||
void blasfeo_ref_dtrmv_lnu(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z <= A' * x ; A lower triangular, unit diagonal
|
||||
void blasfeo_ref_dtrmv_ltu(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
|
||||
// z_n <= beta_n * y_n + alpha_n * A * x_n
|
||||
// z_t <= beta_t * y_t + alpha_t * A' * x_t
|
||||
void blasfeo_ref_dgemv_nt(int m, int n, double alpha_n, double alpha_t, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx_n, int xi_n, struct blasfeo_dvec *sx_t, int xi_t, double beta_n, double beta_t, struct blasfeo_dvec *sy_n, int yi_n, struct blasfeo_dvec *sy_t, int yi_t, struct blasfeo_dvec *sz_n, int zi_n, struct blasfeo_dvec *sz_t, int zi_t);
|
||||
// z <= beta * y + alpha * A * x, where A is symmetric and only the lower triangular patr of A is accessed
|
||||
void blasfeo_ref_dsymv_l(int m, int n, double alpha, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, double beta, struct blasfeo_dvec *sy, int yi, struct blasfeo_dvec *sz, int zi);
|
||||
void blasfeo_ref_dsymv_l(int m, double alpha, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, double beta, struct blasfeo_dvec *sy, int yi, struct blasfeo_dvec *sz, int zi);
|
||||
void blasfeo_ref_dsymv_l_mn(int m, int n, double alpha, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, double beta, struct blasfeo_dvec *sy, int yi, struct blasfeo_dvec *sz, int zi);
|
||||
|
||||
// diagonal
|
||||
|
||||
|
||||
+205
-6
@@ -54,6 +54,158 @@ void blasfeo_align_4096_byte(void *ptr, void **ptr_align);
|
||||
void blasfeo_align_64_byte(void *ptr, void **ptr_align);
|
||||
|
||||
|
||||
//
|
||||
// lib8
|
||||
//
|
||||
|
||||
// 24x8
|
||||
void kernel_dgemm_nt_24x8_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd); //
|
||||
void kernel_dgemm_nt_24x8_vs_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, int m1, int n1); //
|
||||
void kernel_dtrsm_nt_rl_inv_24x8_lib8(int k, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E); //
|
||||
void kernel_dpotrf_nt_l_24x8_lib8(int k, double *A, int sda, double *B, double *C, int sdc, double *D, int sdd, double *inv_diag_D);
|
||||
void kernel_dpotrf_nt_l_24x8_vs_lib8(int k, double *A, int sda, double *B, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int m1, int n1);
|
||||
void kernel_dtrsm_nt_rl_inv_24x8_vs_lib8(int k, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int m1, int n1); //
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_24x8_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E);
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_24x8_vs_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int m1, int n1);
|
||||
void kernel_dsyrk_dpotrf_nt_l_24x8_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D);
|
||||
void kernel_dsyrk_dpotrf_nt_l_24x8_vs_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int m1, int n1);
|
||||
void kernel_dlarfb8_rn_24_lib8(int kmax, double *pV, double *pT, double *pD, int sdd);
|
||||
// 16x8
|
||||
void kernel_dgemm_nt_16x8_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd); //
|
||||
void kernel_dgemm_nt_16x8_vs_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, int m1, int n1); //
|
||||
void kernel_dgemm_nt_16x8_gen_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
|
||||
void kernel_dgemm_nn_16x8_lib8(int k, double *alpha, double *A, int sda, int offB, double *B, int sdb, double *beta, double *C, int sdc, double *D, int sdd); //
|
||||
void kernel_dgemm_nn_16x8_vs_lib8(int k, double *alpha, double *A, int sda, int offB, double *B, int sdb, double *beta, double *C, int sdc, double *D, int sdd, int m1, int n1); //
|
||||
void kernel_dgemm_nn_16x8_gen_lib8(int k, double *alpha, double *A, int sda, int offB, double *B, int sdb, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
|
||||
void kernel_dsyrk_nt_l_16x8_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd); //
|
||||
void kernel_dsyrk_nt_l_16x8_vs_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, int m1, int n1); //
|
||||
void kernel_dsyrk_nt_l_16x8_gen_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
|
||||
void kernel_dtrmm_nn_rl_16x8_lib8(int k, double *alpha, double *A, int sda, int offsetB, double *B, int sdb, double *D, int sdd);
|
||||
void kernel_dtrmm_nn_rl_16x8_vs_lib8(int k, double *alpha, double *A, int sda, int offsetB, double *B, int sdb, double *D, int sdd, int m1, int n1);
|
||||
void kernel_dtrmm_nn_rl_16x8_gen_lib8(int k, double *alpha, double *A, int sda, int offsetB, double *B, int sdb, int offD, double *D, int sdd, int m0, int m1, int n0, int n1);
|
||||
void kernel_dtrsm_nt_rl_inv_16x8_lib8(int k, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E); //
|
||||
void kernel_dtrsm_nt_rl_inv_16x8_vs_lib8(int k, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int m1, int n1); //
|
||||
void kernel_dpotrf_nt_l_16x8_lib8(int k, double *A, int sda, double *B, double *C, int sdc, double *D, int sdd, double *inv_diag_D);
|
||||
void kernel_dpotrf_nt_l_16x8_vs_lib8(int k, double *A, int sda, double *B, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int m1, int n1);
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_16x8_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E);
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_16x8_vs_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int m1, int n1);
|
||||
void kernel_dsyrk_dpotrf_nt_l_16x8_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D);
|
||||
void kernel_dsyrk_dpotrf_nt_l_16x8_vs_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int m1, int n1);
|
||||
void kernel_dlarfb8_rn_16_lib8(int kmax, double *pV, double *pT, double *pD, int sdd);
|
||||
void kernel_dlarfb8_rn_la_16_lib8(int n1, double *pVA, double *pT, double *pD, int sdd, double *pA, int sda);
|
||||
void kernel_dlarfb8_rn_lla_16_lib8(int n0, int n1, double *pVL, double *pVA, double *pT, double *pD, int sdd, double *pL, int sdl, double *pA, int sda);
|
||||
// 8x16
|
||||
void kernel_dgemm_tt_8x16_lib8(int k, double *alpha, int offA, double *A, int sda, double *B, int sdb, double *beta, double *C, double *D); //
|
||||
void kernel_dgemm_tt_8x16_vs_lib8(int k, double *alpha, int offA, double *A, int sda, double *B, int sdb, double *beta, double *C, double *D, int m1, int n1); //
|
||||
void kernel_dgemm_tt_8x16_gen_lib8(int k, double *alpha, int offA, double *A, int sda, double *B, int sdb, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
|
||||
void kernel_dgemm_nt_8x16_lib8(int k, double *alpha, double *A, double *B, int sdb, double *beta, double *C, double *D); //
|
||||
void kernel_dgemm_nt_8x16_vs_lib8(int k, double *alpha, double *A, double *B, int sdb, double *beta, double *C, double *D, int m1, int n1); //
|
||||
// 8x8
|
||||
void kernel_dgemm_nt_8x8_lib8(int k, double *alpha, double *A, double *B, double *beta, double *C, double *D); //
|
||||
void kernel_dgemm_nt_8x8_vs_lib8(int k, double *alpha, double *A, double *B, double *beta, double *C, double *D, int m1, int n1); //
|
||||
void kernel_dgemm_nt_8x8_gen_lib8(int k, double *alpha, double *A, double *B, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
|
||||
void kernel_dgemm_nn_8x8_lib8(int k, double *alpha, double *A, int offB, double *B, int sdb, double *beta, double *C, double *D); //
|
||||
void kernel_dgemm_nn_8x8_vs_lib8(int k, double *alpha, double *A, int offB, double *B, int sdb, double *beta, double *C, double *D, int m1, int n1); //
|
||||
void kernel_dgemm_nn_8x8_gen_lib8(int k, double *alpha, double *A, int offB, double *B, int sdb, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
|
||||
void kernel_dgemm_tt_8x8_lib8(int k, double *alpha, int offA, double *A, int sda, double *B, double *beta, double *C, double *D); //
|
||||
void kernel_dgemm_tt_8x8_vs_lib8(int k, double *alpha, int offA, double *A, int sda, double *B, double *beta, double *C, double *D, int m1, int n1); //
|
||||
void kernel_dgemm_tt_8x8_gen_lib8(int k, double *alpha, int offA, double *A, int sda, double *B, double *beta, int offc, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
|
||||
void kernel_dsyrk_nt_l_8x8_lib8(int k, double *alpha, double *A, double *B, double *beta, double *C, double *D); //
|
||||
void kernel_dsyrk_nt_l_8x8_vs_lib8(int k, double *alpha, double *A, double *B, double *beta, double *C, double *D, int m1, int n1); //
|
||||
void kernel_dsyrk_nt_l_8x8_gen_lib8(int k, double *alpha, double *A, double *B, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
|
||||
void kernel_dtrmm_nn_rl_8x8_lib8(int k, double *alpha, double *A, int offsetB, double *B, int sdb, double *D);
|
||||
void kernel_dtrmm_nn_rl_8x8_vs_lib8(int k, double *alpha, double *A, int offsetB, double *B, int sdb, double *D, int m1, int n1);
|
||||
void kernel_dtrmm_nn_rl_8x8_gen_lib8(int k, double *alpha, double *A, int offsetB, double *B, int sdb, int offD, double *D, int sdd, int m0, int m1, int n0, int n1);
|
||||
void kernel_dtrsm_nt_rl_inv_8x8_lib8(int k, double *A, double *B, double *beta, double *C, double *D, double *E, double *inv_diag_E);
|
||||
void kernel_dtrsm_nt_rl_inv_8x8_vs_lib8(int k, double *A, double *B, double *beta, double *C, double *D, double *E, double *inv_diag_E, int m1, int n1);
|
||||
void kernel_dpotrf_nt_l_8x8_lib8(int k, double *A, double *B, double *C, double *D, double *inv_diag_D);
|
||||
void kernel_dpotrf_nt_l_8x8_vs_lib8(int k, double *A, double *B, double *C, double *D, double *inv_diag_D, int m1, int n1);
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_8x8_lib8(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *E, double *inv_diag_E);
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_8x8_vs_lib8(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *E, double *inv_diag_E, int m1, int n1);
|
||||
void kernel_dsyrk_dpotrf_nt_l_8x8_lib8(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D);
|
||||
void kernel_dsyrk_dpotrf_nt_l_8x8_vs_lib8(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int m1, int n1);
|
||||
void kernel_dgelqf_vs_lib8(int m, int n, int k, int offD, double *pD, int sdd, double *dD);
|
||||
void kernel_dgelqf_pd_vs_lib8(int m, int n, int k, int offD, double *pD, int sdd, double *dD);
|
||||
void kernel_dgelqf_8_lib8(int kmax, double *pD, double *dD);
|
||||
void kernel_dgelqf_pd_8_lib8(int kmax, double *pD, double *dD);
|
||||
void kernel_dlarft_8_lib8(int kmax, double *pD, double *dD, double *pT);
|
||||
void kernel_dlarfb8_rn_8_lib8(int kmax, double *pV, double *pT, double *pD);
|
||||
void kernel_dlarfb8_rn_8_vs_lib8(int kmax, double *pV, double *pT, double *pD, int m1);
|
||||
void kernel_dlarfb8_rn_1_lib8(int kmax, double *pV, double *pT, double *pD);
|
||||
void kernel_dgelqf_dlarft8_8_lib8(int kmax, double *pD, double *dD, double *pT);
|
||||
void kernel_dgelqf_pd_dlarft8_8_lib8(int kmax, double *pD, double *dD, double *pT);
|
||||
void kernel_dgelqf_pd_la_vs_lib8(int m, int n1, int k, int offD, double *pD, int sdd, double *dD, int offA, double *pA, int sda);
|
||||
void kernel_dgelqf_pd_la_dlarft8_8_lib8(int kmax, double *pD, double *dD, double *pA, double *pT);
|
||||
void kernel_dlarft_la_8_lib8(int n1, double *dD, double *pA, double *pT);
|
||||
void kernel_dlarfb8_rn_la_8_lib8(int n1, double *pVA, double *pT, double *pD, double *pA);
|
||||
void kernel_dlarfb8_rn_la_8_vs_lib8(int n1, double *pVA, double *pT, double *pD, double *pA, int m1);
|
||||
void kernel_dlarfb8_rn_la_1_lib8(int n1, double *pVA, double *pT, double *pD, double *pA);
|
||||
void kernel_dgelqf_pd_lla_vs_lib8(int m, int n0, int n1, int k, int offD, double *pD, int sdd, double *dD, int offL, double *pL, int sdl, int offA, double *pA, int sda);
|
||||
void kernel_dgelqf_pd_lla_dlarft8_8_lib8(int n0, int n1, double *pD, double *dD, double *pL, double *pA, double *pT);
|
||||
void kernel_dlarft_lla_8_lib8(int n0, int n1, double *dD, double *pL, double *pA, double *pT);
|
||||
void kernel_dlarfb8_rn_lla_8_lib8(int n0, int n1, double *pVL, double *pVA, double *pT, double *pD, double *pL, double *pA);
|
||||
void kernel_dlarfb8_rn_lla_8_vs_lib8(int n0, int n1, double *pVL, double *pVA, double *pT, double *pD, double *pL, double *pA, int m1);
|
||||
void kernel_dlarfb8_rn_lla_1_lib8(int n0, int n1, double *pVL, double *pVA, double *pT, double *pD, double *pL, double *pA);
|
||||
|
||||
// panel copy / pack
|
||||
// 24
|
||||
void kernel_dpack_nn_24_lib8(int kmax, double *A, int lda, double *C, int sdc);
|
||||
void kernel_dpack_nn_24_vs_lib8(int kmax, double *A, int lda, double *C, int sdc, int m1);
|
||||
// 16
|
||||
void kernel_dpacp_nn_16_lib8(int kmax, int offsetA, double *A, int sda, double *B, int sdb);
|
||||
void kernel_dpacp_nn_16_vs_lib8(int kmax, int offsetA, double *A, int sda, double *B, int sdb, int m1);
|
||||
void kernel_dpack_nn_16_lib8(int kmax, double *A, int lda, double *C, int sdc);
|
||||
void kernel_dpack_nn_16_vs_lib8(int kmax, double *A, int lda, double *C, int sdc, int m1);
|
||||
// 8
|
||||
void kernel_dpacp_nn_8_lib8(int kmax, int offsetA, double *A, int sda, double *B);
|
||||
void kernel_dpacp_nn_8_vs_lib8(int kmax, int offsetA, double *A, int sda, double *B, int m1);
|
||||
void kernel_dpacp_tn_8_lib8(int kmax, int offsetA, double *A, int sda, double *B);
|
||||
void kernel_dpacp_tn_8_vs_lib8(int kmax, int offsetA, double *A, int sda, double *B, int m1);
|
||||
void kernel_dpacp_l_nn_8_lib8(int kmax, int offsetA, double *A, int sda, double *B);
|
||||
void kernel_dpacp_l_nn_8_vs_lib8(int kmax, int offsetA, double *A, int sda, double *B, int m1);
|
||||
void kernel_dpacp_l_tn_8_lib8(int kmax, int offsetA, double *A, int sda, double *B);
|
||||
void kernel_dpacp_l_tn_8_vs_lib8(int kmax, int offsetA, double *A, int sda, double *B, int m1);
|
||||
void kernel_dpaad_nn_8_lib8(int kmax, double *alpha, int offsetA, double *A, int sda, double *B);
|
||||
void kernel_dpaad_nn_8_vs_lib8(int kmax, double *alpha, int offsetA, double *A, int sda, double *B, int m1);
|
||||
void kernel_dpack_nn_8_lib8(int kmax, double *A, int lda, double *C);
|
||||
void kernel_dpack_nn_8_vs_lib8(int kmax, double *A, int lda, double *C, int m1);
|
||||
void kernel_dpack_tn_8_lib8(int kmax, double *A, int lda, double *C);
|
||||
void kernel_dpack_tn_8_vs_lib8(int kmax, double *A, int lda, double *C, int m1);
|
||||
// 4
|
||||
void kernel_dpack_tt_4_lib8(int kmax, double *A, int lda, double *C, int sdc); // TODO offsetC
|
||||
void kernel_dpack_tt_4_vs_lib8(int kmax, double *A, int lda, double *C, int sdc, int m1); // TODO offsetC
|
||||
|
||||
// level 2 BLAS
|
||||
// 16
|
||||
void kernel_dgemv_n_16_lib8(int k, double *alpha, double *A, int sda, double *x, double *beta, double *y, double *z);
|
||||
// 8
|
||||
void kernel_dgemv_n_8_lib8(int k, double *alpha, double *A, double *x, double *beta, double *y, double *z);
|
||||
void kernel_dgemv_n_8_vs_lib8(int k, double *alpha, double *A, double *x, double *beta, double *y, double *z, int m1);
|
||||
//void kernel_dgemv_n_8_gen_lib8(int k, double *alpha, double *A, double *x, double *beta, double *y, double *z, int m0, int m1);
|
||||
void kernel_dgemv_n_8_gen_lib8(int k, double *alpha, int offsetA, double *A, double *x, double *beta, double *y, double *z, int m1);
|
||||
void kernel_dgemv_t_8_lib8(int k, double *alpha, int offsetA, double *A, int sda, double *x, double *beta, double *y, double *z);
|
||||
void kernel_dgemv_t_8_vs_lib8(int k, double *alpha, int offsetA, double *A, int sda, double *x, double *beta, double *y, double *z, int n1);
|
||||
void kernel_dgemv_nt_8_lib8(int kmax, double *alpha_n, double *alpha_t, int offsetA, double *A, int sda, double *x_n, double *x_t, double *beta_t, double *y_t, double *z_n, double *z_t);
|
||||
void kernel_dgemv_nt_8_vs_lib8(int kmax, double *alpha_n, double *alpha_t, int offsetA, double *A, int sda, double *x_n, double *x_t, double *beta_t, double *y_t, double *z_n, double *z_t, int n1);
|
||||
void kernel_dsymv_l_8_lib8(int kmax, double *alpha, double *A, int sda, double *x, double *z);
|
||||
void kernel_dsymv_l_8_vs_lib8(int kmax, double *alpha, double *A, int sda, double *x, double *z, int n1);
|
||||
void kernel_dsymv_l_8_gen_lib8(int kmax, double *alpha, int offsetA, double *A, int sda, double *x, double *z, int n1);
|
||||
void kernel_dtrmv_n_ln_8_lib8(int k, double *A, double *x, double *z);
|
||||
void kernel_dtrmv_n_ln_8_vs_lib8(int k, double *A, double *x, double *z, int m1);
|
||||
void kernel_dtrmv_n_ln_8_gen_lib8(int k, int offsetA, double *A, double *x, double *z, int m1);
|
||||
void kernel_dtrmv_t_ln_8_lib8(int k, double *A, int sda, double *x, double *z);
|
||||
void kernel_dtrmv_t_ln_8_vs_lib8(int k, double *A, int sda, double *x, double *z, int n1);
|
||||
void kernel_dtrmv_t_ln_8_gen_lib8(int k, int offsetA, double *A, int sda, double *x, double *z, int n1);
|
||||
void kernel_dtrsv_n_l_inv_8_lib8(int k, double *A, double *inv_diag_A, double *x, double *z);
|
||||
void kernel_dtrsv_n_l_inv_8_vs_lib8(int k, double *A, double *inv_diag_A, double *x, double *z, int m1, int n1);
|
||||
void kernel_dtrsv_t_l_inv_8_lib8(int k, double *A, int sda, double *inv_diag_A, double *x, double *z);
|
||||
void kernel_dtrsv_t_l_inv_8_vs_lib8(int k, double *A, int sda, double *inv_diag_A, double *x, double *z, int m1, int n1);
|
||||
|
||||
|
||||
|
||||
//
|
||||
// lib4
|
||||
//
|
||||
|
||||
// level 2 BLAS
|
||||
// 12
|
||||
@@ -413,10 +565,10 @@ void kernel_drowsw_lib4(int kmax, double *pA, double *pC);
|
||||
|
||||
// merged routines
|
||||
// 12x4
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_12x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int km, int kn);
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_12x4_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E);
|
||||
void kernel_dsyrk_dpotrf_nt_l_12x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int km, int kn);
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_12x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int km, int kn);
|
||||
void kernel_dsyrk_dpotrf_nt_l_12x4_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D);
|
||||
void kernel_dsyrk_dpotrf_nt_l_12x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int km, int kn);
|
||||
// 4x12
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_4x12_vs_lib4(int kp, double *Ap, double *Bp, int sdbp, int km_, double *Am, double *Bm, int sdbm, double *C, double *D, double *E, int sde, double *inv_diag_E, int km, int kn);
|
||||
// 8x8
|
||||
@@ -425,8 +577,8 @@ void kernel_dsyrk_dpotrf_nt_l_8x8_vs_lib4(int kp, double *Ap, int sdap, double *
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_8x8l_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int sdb, int km_, double *Am, int sdam, double *Bm, int sdbm, double *C, int sdc, double *D, int sdd, double *E, int sde, double *inv_diag_E, int km, int kn);
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_8x8u_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int sdb, int km_, double *Am, int sdam, double *Bm, int sdbm, double *C, int sdc, double *D, int sdd, double *E, int sde, double *inv_diag_E, int km, int kn);
|
||||
// 8x4
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_8x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int km, int kn);
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_8x4_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E);
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_8x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int km, int kn);
|
||||
void kernel_dsyrk_dpotrf_nt_l_8x4_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D);
|
||||
void kernel_dsyrk_dpotrf_nt_l_8x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int km, int kn);
|
||||
// 4x8
|
||||
@@ -434,16 +586,16 @@ void kernel_dgemm_dtrsm_nt_rl_inv_4x8_vs_lib4(int kp, double *Ap, double *Bp, in
|
||||
// 4x4
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_4x4_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *E, double *inv_diag_E);
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_4x4_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *E, double *inv_diag_E, int km, int kn);
|
||||
void kernel_dsyrk_dpotrf_nt_l_4x4_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int km, int kn);
|
||||
void kernel_dsyrk_dpotrf_nt_l_4x4_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D);
|
||||
void kernel_dsyrk_dpotrf_nt_l_4x4_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int km, int kn);
|
||||
// 4x2
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_4x2_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *E, double *inv_diag_E);
|
||||
void kernel_dgemm_dtrsm_nt_rl_inv_4x2_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *E, double *inv_diag_E, int km, int kn);
|
||||
void kernel_dsyrk_dpotrf_nt_l_4x2_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int km, int kn);
|
||||
void kernel_dsyrk_dpotrf_nt_l_4x2_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D);
|
||||
void kernel_dsyrk_dpotrf_nt_l_4x2_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int km, int kn);
|
||||
// 2x2
|
||||
void kernel_dsyrk_dpotrf_nt_l_2x2_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int km, int kn);
|
||||
void kernel_dsyrk_dpotrf_nt_l_2x2_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D);
|
||||
void kernel_dsyrk_dpotrf_nt_l_2x2_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int km, int kn);
|
||||
|
||||
/*
|
||||
*
|
||||
@@ -1034,6 +1186,53 @@ void kernel_dgemm_nt_8xn_p0_lib44cc(int n, int k, double *alpha, double *A, int
|
||||
|
||||
|
||||
|
||||
// A, B panel-major bs=8; C, D column-major
|
||||
// 24x8
|
||||
void kernel_dgemm_nt_24x8_lib88cc(int kmax, double *alpha, double *A, int sda, double *B, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_nt_24x8_vs_lib88cc(int kmax, double *alpha, double *A, int sda, double *B, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
// 16x8
|
||||
void kernel_dgemm_nt_16x8_lib88cc(int kmax, double *alpha, double *A, int sda, double *B, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_nt_16x8_vs_lib88cc(int kmax, double *alpha, double *A, int sda, double *B, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
// 8x8
|
||||
void kernel_dgemm_nt_8x8_lib88cc(int kmax, double *alpha, double *A, double *B, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_nt_8x8_vs_lib88cc(int kmax, double *alpha, double *A, double *B, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
|
||||
// A, panel-major bs=8; B, C, D column-major
|
||||
// 24x8
|
||||
void kernel_dgemm_nt_24x8_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_nt_24x8_vs_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
void kernel_dgemm_nn_24x8_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_nn_24x8_vs_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
// 16x8
|
||||
void kernel_dgemm_nt_16x8_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_nt_16x8_vs_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
void kernel_dgemm_nn_16x8_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_nn_16x8_vs_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
// 8x8
|
||||
void kernel_dgemm_nt_8x8_lib8ccc(int kmax, double *alpha, double *A, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_nt_8x8_vs_lib8ccc(int kmax, double *alpha, double *A, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
void kernel_dgemm_nn_8x8_lib8ccc(int kmax, double *alpha, double *A, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_nn_8x8_vs_lib8ccc(int kmax, double *alpha, double *A, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
|
||||
// B, panel-major bs=8; A, C, D column-major
|
||||
// 8x24
|
||||
void kernel_dgemm_nt_8x24_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_nt_8x24_vs_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
void kernel_dgemm_tt_8x24_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_tt_8x24_vs_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
// 8x16
|
||||
void kernel_dgemm_nt_8x16_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_nt_8x16_vs_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
void kernel_dgemm_tt_8x16_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_tt_8x16_vs_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
// 8x8
|
||||
void kernel_dgemm_nt_8x8_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_nt_8x8_vs_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
void kernel_dgemm_tt_8x8_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, double *beta, double *C, int ldc, double *D, int ldd);
|
||||
void kernel_dgemm_tt_8x8_vs_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
|
||||
|
||||
|
||||
|
||||
// aux
|
||||
void kernel_dvecld_inc1(int kmax, double *x);
|
||||
void kernel_dveccp_inc1(int kmax, double *x, double *y);
|
||||
|
||||
@@ -97,14 +97,18 @@ void blasfeo_strsv_ltu(int m, struct blasfeo_smat *sA, int ai, int aj, struct bl
|
||||
void blasfeo_strsv_unn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= inv( A' ) * x, A (m)x(m) upper, transposed, not_unit
|
||||
void blasfeo_strsv_utn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular
|
||||
void blasfeo_strmv_lnn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular, unit diagonal
|
||||
void blasfeo_strmv_lnu(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A' * x ; A lower triangular
|
||||
void blasfeo_strmv_ltn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A' * x ; A lower triangular, unit diagonal
|
||||
void blasfeo_strmv_ltu(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= beta * y + alpha * A * x ; A upper triangular
|
||||
void blasfeo_strmv_unn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A' * x ; A upper triangular
|
||||
void blasfeo_strmv_utn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular
|
||||
void blasfeo_strmv_lnn(int m, int n, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A' * x ; A lower triangular
|
||||
void blasfeo_strmv_ltn(int m, int n, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z_n <= beta_n * y_n + alpha_n * A * x_n
|
||||
// z_t <= beta_t * y_t + alpha_t * A' * x_t
|
||||
void blasfeo_sgemv_nt(int m, int n, float alpha_n, float alpha_t, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx_n, int xi_n, struct blasfeo_svec *sx_t, int xi_t, float beta_n, float beta_t, struct blasfeo_svec *sy_n, int yi_n, struct blasfeo_svec *sy_t, int yi_t, struct blasfeo_svec *sz_n, int zi_n, struct blasfeo_svec *sz_t, int zi_t);
|
||||
|
||||
@@ -97,19 +97,24 @@ void blasfeo_ref_strsv_ltu(int m, struct blasfeo_smat *sA, int ai, int aj, struc
|
||||
void blasfeo_ref_strsv_unn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= inv( A' ) * x, A (m)x(m) upper, transposed, not_unit
|
||||
void blasfeo_ref_strsv_utn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular
|
||||
void blasfeo_ref_strmv_lnn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular, unit diagonal
|
||||
void blasfeo_ref_strmv_lnu(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A' * x ; A lower triangular
|
||||
void blasfeo_ref_strmv_ltn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A' * x ; A lower triangular, unit diagonal
|
||||
void blasfeo_ref_strmv_ltu(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= beta * y + alpha * A * x ; A upper triangular
|
||||
void blasfeo_ref_strmv_unn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A' * x ; A upper triangular
|
||||
void blasfeo_ref_strmv_utn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A * x ; A lower triangular
|
||||
void blasfeo_ref_strmv_lnn(int m, int n, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z <= A' * x ; A lower triangular
|
||||
void blasfeo_ref_strmv_ltn(int m, int n, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
|
||||
// z_n <= beta_n * y_n + alpha_n * A * x_n
|
||||
// z_t <= beta_t * y_t + alpha_t * A' * x_t
|
||||
void blasfeo_ref_sgemv_nt(int m, int n, float alpha_n, float alpha_t, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx_n, int xi_n, struct blasfeo_svec *sx_t, int xi_t, float beta_n, float beta_t, struct blasfeo_svec *sy_n, int yi_n, struct blasfeo_svec *sy_t, int yi_t, struct blasfeo_svec *sz_n, int zi_n, struct blasfeo_svec *sz_t, int zi_t);
|
||||
// z <= beta * y + alpha * A * x, where A is symmetric and only the lower triangular patr of A is accessed
|
||||
void blasfeo_ref_ssymv_l(int m, int n, float alpha, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, float beta, struct blasfeo_svec *sy, int yi, struct blasfeo_svec *sz, int zi);
|
||||
void blasfeo_ref_ssymv_l(int m, float alpha, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, float beta, struct blasfeo_svec *sy, int yi, struct blasfeo_svec *sz, int zi);
|
||||
void blasfeo_ref_ssymv_l_mn(int m, int n, float alpha, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, float beta, struct blasfeo_svec *sy, int yi, struct blasfeo_svec *sz, int zi);
|
||||
|
||||
// diagonal
|
||||
|
||||
|
||||
@@ -550,6 +550,16 @@ void kernel_spotrf_nt_l_4x4_lib44cc(int kmax, float *A, float *B, float *C, int
|
||||
void kernel_spotrf_nt_l_4x4_vs_lib44cc(int kmax, float *A, float *B, float *C, int ldc, float *D, int ldd, float *dD, int m1, int n1);
|
||||
|
||||
// B panel-major bs=8; A, C, D column-major
|
||||
// 4x24
|
||||
void kernel_sgemm_nt_4x24_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd);
|
||||
void kernel_sgemm_nt_4x24_vs_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd, int m1, int n1);
|
||||
void kernel_sgemm_tt_4x24_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd);
|
||||
void kernel_sgemm_tt_4x24_vs_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd, int m1, int n1);
|
||||
// 4x16
|
||||
void kernel_sgemm_nt_4x16_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd);
|
||||
void kernel_sgemm_nt_4x16_vs_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd, int m1, int n1);
|
||||
void kernel_sgemm_tt_4x16_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd);
|
||||
void kernel_sgemm_tt_4x16_vs_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd, int m1, int n1);
|
||||
// 8x8
|
||||
void kernel_sgemm_nt_8x8_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, float *beta, float *C, int ldc, float *D, int ldd);
|
||||
void kernel_sgemm_nt_8x8_vs_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, float *beta, float *C, int ldc, float *D, int ldd, int m1, int n1);
|
||||
|
||||
+2
-2
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ef769d1e86261357234b3260f1aa8a0eb317858251eff83bcfdd8cdd60068664
|
||||
size 485104
|
||||
oid sha256:9303c7ff72999cc1e799faf24b715112326eefefb842240d4da19beb8caba1d7
|
||||
size 484904
|
||||
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8c043d1a78e7bf885954fa40ae308d5b7456f5b4f796461c51d00f7f47be5c59
|
||||
oid sha256:654deb9f5c9cca7c6c049f8cffc9eae40052d7971489690929b5efded4b9d9da
|
||||
size 730608
|
||||
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2d51902b5af8d29750f76611c956a66f070fef9334e0945f2aec65365834c4cb
|
||||
oid sha256:c7494826f51ccbbe43da2df236216dad90bbd2f378accee9a1828f63fac80599
|
||||
size 1367352
|
||||
|
||||
+2
-2
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ad96e016d31a7195786b9dc011e94ad78d01199da77ec4288a17bf4b69f33021
|
||||
size 525968
|
||||
oid sha256:304b272189540640455e2003b6b124ce134edbc537d9f7e266939a6163e3965c
|
||||
size 521320
|
||||
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c41f340411cbd4adf8d3065298c094f438d432304234bcb24403062a0a10caa5
|
||||
oid sha256:7eed704b395c5ab6eeff9e4188ee4f883889c666adcea7743cf2bdcef535f533
|
||||
size 1265064
|
||||
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8d15b89d07c01f24e668961797f9e163d3aab1fb17006164fa8d1c850908fc49
|
||||
oid sha256:58abc1d388507fbab4c92483b58a4a2dd985877cd0b568c3135436cadc8f616e
|
||||
size 1572648
|
||||
|
||||
Reference in New Issue
Block a user