import argparse
import logging
import os
from pathlib import Path
from typing import Any, Optional, Union

from metatensor.torch.atomistic import ModelMetadata, is_atomistic_model
from omegaconf import OmegaConf

from ..utils.io import check_file_extension, load_model
from .formatter import CustomHelpFormatter


logger = logging.getLogger(__name__)


def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
    """Add `export_model` parameters to an argparse (sub)-parser.

    :param subparser: sub-parser action on which the ``export`` command is
        registered
    """
    # The command description is everything in the `export_model` docstring that
    # precedes the first ``:param`` field. The docstring is ``None`` when Python
    # runs with ``-OO``.
    if export_model.__doc__ is not None:
        description = export_model.__doc__.split(r":param")[0]
    else:
        description = None

    # If you change the synopsis of these commands or add new ones adjust the completion
    # script at `src/metatrain/share/metatrain-completion.bash`.
    parser = subparser.add_parser(
        "export",
        description=description,
        formatter_class=CustomHelpFormatter,
    )
    parser.set_defaults(callable="export_model")

    parser.add_argument(
        "path",
        type=str,
        help=(
            "Saved model which should be exported. Path can be either a URL or a "
            "local file."
        ),
    )
    parser.add_argument(
        "-o",
        "--output",
        dest="output",
        type=str,
        required=False,
        help=(
            "Filename of the exported model (default: <stem>.pt, "
            "where <stem> is the name of the checkpoint without the extension)."
        ),
    )
    parser.add_argument(
        "-m",
        "--metadata",
        type=str,
        required=False,
        dest="metadata",
        default=None,
        help="Metadata YAML file to be appended to the model.",
    )
    parser.add_argument(
        "--token",
        dest="token",
        type=str,
        required=False,
        default=None,
        help=(
            "HuggingFace API token to download (private) models from HuggingFace. "
            "You can also set an environment variable `HF_TOKEN` to avoid passing "
            "it every time."
        ),
    )


def _prepare_export_model_args(args: argparse.Namespace) -> None:
    """Prepare arguments for export_model.

    The namespace is modified in place: ``path`` and ``token`` are consumed to
    load the model (stored as ``args.model``), ``metadata`` is replaced by a
    ``ModelMetadata`` built from the YAML file when one is given, every key not
    accepted by ``export_model`` is dropped, and ``output`` defaults to
    ``<stem of path>.pt``.

    :param args: parsed command line arguments of the ``export`` sub-command
    :raises ValueError: if a token is given on the command line while the
        ``HF_TOKEN`` environment variable is also set
    """
    arg_dict = vars(args)  # same object as ``args.__dict__``

    path = arg_dict.pop("path")
    token = arg_dict.pop("token")

    # Use the environment variable if available. An empty ``HF_TOKEN`` counts as
    # unset; supplying a token through both channels is treated as an error.
    env_token = os.environ.get("HF_TOKEN")
    if env_token:
        if token is not None:
            raise ValueError(
                "Both CLI and environment variable tokens are set for HuggingFace. "
                "Please use only one."
            )
        token = env_token

    args.model = load_model(path=path, token=token)

    if args.metadata is not None:
        args.metadata = ModelMetadata(**OmegaConf.load(args.metadata))

    # only these are needed for `export_model`
    keys_to_keep = ("model", "output", "metadata")
    # Iterate over a snapshot of the keys because entries are removed in the loop.
    for key in list(arg_dict):
        if key not in keys_to_keep:
            del arg_dict[key]

    if arg_dict.get("output") is None:
        arg_dict["output"] = Path(path).stem + ".pt"
def export_model(
    model: Any, output: Union[Path, str], metadata: Optional[ModelMetadata] = None
) -> None:
    """Export a trained model allowing it to make predictions.

    This includes predictions within molecular simulation engines. Exported models
    will be saved with a ``.pt`` file ending. If ``output`` does not end with this
    file extension, ``.pt`` will be added and a warning emitted.

    :param model: model to be exported
    :param output: path to save the model
    :param metadata: metadata to be appended to the model
    """
    # ``resolve()`` already yields an absolute path, so no separate ``absolute()``
    # call is needed.
    path = str(
        Path(check_file_extension(filename=output, extension=".pt")).resolve()
    )
    # Extensions are collected in an ``extensions/`` directory relative to the
    # current working directory.
    extensions_path = str(Path("extensions/").resolve())

    # NOTE: ``metadata`` is only used on this branch; a model that already is an
    # atomistic model is saved as-is and ``metadata`` is not applied to it.
    if not is_atomistic_model(model):
        model = model.export(metadata)

    model.save(path, collect_extensions=extensions_path)
    logger.info(
        "Model exported to '%s' and extensions to '%s'", path, extensions_path
    )