kronos-trader/model/__init__.py
shiyu-coder 9f946dec6b initial
2025-07-01 10:57:41 +08:00

18 lines
412 B
Python

from .kronos import KronosTokenizer, Kronos, KronosPredictor
# Registry of the model classes re-exported from ``.kronos``, keyed by the
# string names that ``get_model_class`` accepts.
model_dict = dict(
    kronos_tokenizer=KronosTokenizer,
    kronos=Kronos,
    kronos_predictor=KronosPredictor,
)
def get_model_class(model_name: str):
    """Return the model class registered under *model_name* in ``model_dict``.

    Args:
        model_name: Registry key, e.g. ``'kronos'``, ``'kronos_tokenizer'``
            or ``'kronos_predictor'``.

    Returns:
        The class object stored in ``model_dict`` for that key.

    Raises:
        NotImplementedError: If *model_name* is not a key of ``model_dict``.
            The message names the requested model and lists the valid keys.
            The exception type is unchanged so existing ``except`` clauses
            keep working.
    """
    try:
        return model_dict[model_name]
    except KeyError:
        # Carry the diagnostic in the exception itself instead of printing to
        # stdout, so callers that catch or log the error can see which name
        # was requested. ``from None`` hides the internal KeyError traceback.
        available = ", ".join(repr(name) for name in model_dict)
        raise NotImplementedError(
            f"Model {model_name} not found in model_dict "
            f"(available: {available})"
        ) from None