mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-06-28 01:52:06 +08:00
parse planplus (#37386)
This commit is contained in:
@@ -113,6 +113,8 @@ class Parser:
|
||||
plan_mhp = self.is_mhp(outs, 'plan', ModelConstants.IDX_N * ModelConstants.PLAN_WIDTH)
|
||||
plan_in_N, plan_out_N = (ModelConstants.PLAN_MHP_N, ModelConstants.PLAN_MHP_SELECTION) if plan_mhp else (0, 0)
|
||||
self.parse_mdn('plan', outs, in_N=plan_in_N, out_N=plan_out_N, out_shape=(ModelConstants.IDX_N, ModelConstants.PLAN_WIDTH))
|
||||
if 'planplus' in outs:
|
||||
self.parse_mdn('planplus', outs, in_N=0, out_N=0, out_shape=(ModelConstants.IDX_N, ModelConstants.PLAN_WIDTH))
|
||||
self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,))
|
||||
return outs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user