Update bin2safetensors/convert.py (#2)
Browse files- Update bin2safetensors/convert.py (62bc6924b6d614a9e1b4ad5f3f1dd90b06573eab)
Co-authored-by: None <[email protected]>
- bin2safetensors/convert.py +20 -2
bin2safetensors/convert.py
CHANGED
@@ -312,7 +312,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
|
|
312 |
return new_pr, errors
|
313 |
|
314 |
|
315 |
-
def main(input_directory, output_directory):
|
316 |
# Get a list of all files in the input directory
|
317 |
files = os.listdir(input_directory)
|
318 |
|
@@ -360,11 +360,29 @@ def main(input_directory, output_directory):
|
|
360 |
output_filename = os.path.join(output_directory, f"model-{i:05d}-of-{yyyyy:05d}.safetensors")
|
361 |
convert_file(input_filename, output_filename)
|
362 |
print(f"Converted {input_filename} to {output_filename}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
|
364 |
if __name__ == "__main__":
|
365 |
parser = argparse.ArgumentParser(description="Convert pytorch_model model to safetensor and copy JSON and .model files.")
|
366 |
parser.add_argument("input_directory", help="Path to the input directory containing pytorch_model files")
|
367 |
parser.add_argument("output_directory", help="Path to the output directory for converted safetensor files")
|
|
|
|
|
368 |
args = parser.parse_args()
|
369 |
|
370 |
-
main(args.input_directory, args.output_directory)
|
|
|
312 |
return new_pr, errors
|
313 |
|
314 |
|
315 |
+
def main(input_directory, output_directory, delete_files, delete_input_directory):
|
316 |
# Get a list of all files in the input directory
|
317 |
files = os.listdir(input_directory)
|
318 |
|
|
|
360 |
output_filename = os.path.join(output_directory, f"model-{i:05d}-of-{yyyyy:05d}.safetensors")
|
361 |
convert_file(input_filename, output_filename)
|
362 |
print(f"Converted {input_filename} to {output_filename}")
|
363 |
+
|
364 |
+
# Delete the pytorch_model file if the delete_files flag or delete_input_directory flag are set
|
365 |
+
if delete_files or delete_input_directory:
|
366 |
+
os.remove(input_filename)
|
367 |
+
print(f"Deleted {input_filename}")
|
368 |
+
|
369 |
+
# Check if there are any remaining pytorch_model files in the input directory
|
370 |
+
remaining_model_files = [file for file in os.listdir(input_directory) if re.match(r'pytorch_model-\d{5}-of-\d{5}\.bin', file)]
|
371 |
+
|
372 |
+
if len(remaining_model_files) == 0:
|
373 |
+
# Delete the input directory if all files have been converted successfully
|
374 |
+
if delete_input_directory:
|
375 |
+
shutil.rmtree(input_directory)
|
376 |
+
print(f"Deleted input directory {input_directory}")
|
377 |
+
else:
|
378 |
+
print("Warning: Input directory still contains pytorch_model files and won't be deleted.")
|
379 |
|
380 |
if __name__ == "__main__":
|
381 |
parser = argparse.ArgumentParser(description="Convert pytorch_model model to safetensor and copy JSON and .model files.")
|
382 |
parser.add_argument("input_directory", help="Path to the input directory containing pytorch_model files")
|
383 |
parser.add_argument("output_directory", help="Path to the output directory for converted safetensor files")
|
384 |
+
parser.add_argument("-d", "--delete", action="store_true", help="Delete pytorch_model files after conversion")
|
385 |
+
parser.add_argument("-D", "--delete-input", action="store_true", help="Delete pytorch_model files after conversion as well as the input directory after all files are converted")
|
386 |
args = parser.parse_args()
|
387 |
|
388 |
+
main(args.input_directory, args.output_directory, args.delete, args.delete_input)
|