-
Notifications
You must be signed in to change notification settings - Fork 68
Added a script to restore punctuation #38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Paul-HenriBJT
wants to merge
18
commits into
huggingface:main
Choose a base branch
from
Paul-HenriBJT:punctuation-script
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
f880b3a
added a script to restore punctuation
3adfe18
updated requirements.txt with deepmultilingualpunctuation
a3c5501
added true casing
Paul-HenriBJT d85bd6d
updated handling of language
Paul-HenriBJT d25b7d2
updated requirements.txt
Paul-HenriBJT 9b7ddb9
removed unecessary languages
Paul-HenriBJT 289212c
requirements test
Paul-HenriBJT b21b27c
requirements test
Paul-HenriBJT c8c447c
update requirements
Paul-HenriBJT ba519c8
update requirements
Paul-HenriBJT 13e314c
update requirements
Paul-HenriBJT a1a0e79
update requirements
Paul-HenriBJT e77e9e0
update requirements
Paul-HenriBJT 0d8ed85
Added true casing for 8 languages
Paul-HenriBJT 018a3d3
added processing of the text_description column
Paul-HenriBJT df6e6ed
Merge branch 'casing' into punctuation-script
Paul-HenriBJT e3bf102
moved the dowload of the spacy models inside the restore punctuation …
Paul-HenriBJT 9156758
added num_proc=args.num_proc
Paul-HenriBJT File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,4 +5,6 @@ g2p | |
| demucs | ||
| transformers | ||
| accelerate | ||
| bitsandbytes | ||
| bitsandbytes | ||
| deepmultilingualpunctuation | ||
| spacy | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| from datasets import load_dataset | ||
| import argparse | ||
| import re | ||
|
|
||
| def capitalize_first_letter(text): | ||
| return '. '.join(sentence.capitalize() for sentence in text.split('. ')) | ||
|
|
||
| def capitalize_words_remove_quotes(text): | ||
| # Remove quotes and capitalize each word | ||
| return ' '.join(word.capitalize() for word in re.findall(r'\w+', text)) | ||
|
|
||
| def apply_recasing(examples, text_column, description_column): | ||
| recased_texts = [capitalize_first_letter(text) for text in examples[text_column]] | ||
| recased_descriptions = [capitalize_words_remove_quotes(desc) for desc in examples[description_column]] | ||
| return { | ||
| f"original_{text_column}": examples[text_column], | ||
| text_column: recased_texts, | ||
| f"original_{description_column}": examples[description_column], | ||
| description_column: recased_descriptions | ||
| } | ||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser() | ||
|
|
||
| parser.add_argument("dataset_name", type=str, help="Path or name of the dataset.") | ||
| parser.add_argument("--configuration", default=None, type=str, help="Dataset configuration to use, if necessary.") | ||
| parser.add_argument("--output_dir", default=None, type=str, help="If specified, save the dataset on disk with this path.") | ||
| parser.add_argument("--repo_id", default=None, type=str, help="If specified, push the dataset to the hub.") | ||
| parser.add_argument("--text_column", default="text", type=str, help="Name of the column containing the text to be recased.") | ||
| parser.add_argument("--description_column", default="text_description", type=str, help="Name of the column containing the description to be recased and cleaned.") | ||
| parser.add_argument("--batch_size", default=32, type=int, help="Batch size for processing.") | ||
| parser.add_argument("--num_proc", default=1, type=int, help="Number of processes to use.") | ||
|
|
||
| args = parser.parse_args() | ||
|
|
||
| if args.configuration: | ||
| dataset = load_dataset(args.dataset_name, args.configuration, num_proc=args.num_proc) | ||
| else: | ||
| dataset = load_dataset(args.dataset_name, num_proc=args.num_proc) | ||
|
|
||
| recased_dataset = dataset.map( | ||
| apply_recasing, | ||
| batched=True, | ||
| num_proc=args.num_proc, | ||
| batch_size=args.batch_size, | ||
| fn_kwargs={"text_column": args.text_column, "description_column": args.description_column}, | ||
| desc="Applying recasing" | ||
| ) | ||
|
|
||
| if args.output_dir: | ||
| print("Saving to disk...") | ||
| recased_dataset.save_to_disk(args.output_dir) | ||
|
|
||
| if args.repo_id: | ||
| print("Pushing to the hub...") | ||
| if args.configuration: | ||
| recased_dataset.push_to_hub(args.repo_id, args.configuration) | ||
| else: | ||
| recased_dataset.push_to_hub(args.repo_id) | ||
|
|
||
| print("Recasing completed.") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,210 @@ | ||
| import argparse | ||
| from multiprocessing import set_start_method | ||
| from datasets import load_dataset | ||
| from deepmultilingualpunctuation import PunctuationModel | ||
| import spacy | ||
| from typing import Dict, Callable, List | ||
| import re | ||
| from spacy.cli import download | ||
|
|
||
| nlp_models: Dict[str, spacy.language.Language] = {} | ||
|
|
||
| def load_spacy_model(lang_code: str) -> spacy.language.Language: | ||
| """Load and return the appropriate spaCy model for the given language code. | ||
| Downloads the model if not already installed.""" | ||
| if lang_code not in nlp_models: | ||
| model_name = { | ||
| 'ca': 'ca_core_news_sm', | ||
| 'en': 'en_core_web_sm', | ||
| 'de': 'de_core_news_sm', | ||
| 'fr': 'fr_core_news_sm', | ||
| 'es': 'es_core_news_sm', | ||
| 'it': 'it_core_news_sm', | ||
| 'pl': 'pl_core_news_sm', | ||
| 'nl': 'nl_core_news_sm', | ||
| 'pt': 'pt_core_news_sm', | ||
| }.get(lang_code) | ||
|
|
||
| if model_name is None: | ||
| raise ValueError(f"Unsupported language code: {lang_code}") | ||
|
|
||
| try: | ||
| nlp_models[lang_code] = spacy.load(model_name) | ||
| except OSError: | ||
| print(f"Downloading {model_name}...") | ||
| download(model_name) | ||
| nlp_models[lang_code] = spacy.load(model_name) | ||
|
|
||
| return nlp_models[lang_code] | ||
|
|
||
| def get_capitalization_function(lang_code: str) -> Callable[[spacy.tokens.Token], str]: | ||
| """Return the appropriate capitalization function for the given language.""" | ||
|
|
||
| def default_capitalization(token: spacy.tokens.Token) -> str: | ||
| if token.is_sent_start or token.pos_ in ('PROPN', 'NNP', 'NNPS'): | ||
| return token.text.capitalize() | ||
| return token.text.lower() | ||
|
|
||
| def german_capitalization(token: spacy.tokens.Token) -> str: | ||
| if token.is_sent_start or token.pos_ in ('PROPN', 'NOUN'): | ||
| return token.text.capitalize() | ||
| return token.text.lower() | ||
|
|
||
| if lang_code == 'de': | ||
| return german_capitalization | ||
| else: | ||
| return default_capitalization | ||
|
|
||
| def true_case(text: str, lang_code: str) -> str: | ||
| """ | ||
| Perform true casing on the input text for the specified language. | ||
|
|
||
| :param text: Input text to be true cased | ||
| :param lang_code: Two-letter language code (e.g., 'en' for English) | ||
| :return: True cased text | ||
| """ | ||
| nlp = load_spacy_model(lang_code) | ||
| capitalization_func = get_capitalization_function(lang_code) | ||
|
|
||
| doc = nlp(text) | ||
| true_cased_tokens = [capitalization_func(token) for token in doc] | ||
|
|
||
| # Join tokens, ensuring no space before punctuation | ||
| true_cased_text = "" | ||
| for i, token in enumerate(doc): | ||
| if i > 0 and not token.is_punct: | ||
| true_cased_text += " " | ||
| true_cased_text += true_cased_tokens[i] | ||
|
|
||
| return true_cased_text | ||
|
|
||
| def remove_quotes(text: str) -> str: | ||
| """Remove single and double quotes from the input text.""" | ||
| return re.sub(r"['\"]", "", text) | ||
|
|
||
| def apply_processing(examples, punctuation_model, text_column, description_column, lang_code, punctuation_only, truecase_only): | ||
| result = {} | ||
|
|
||
| # Process text column | ||
| if text_column: | ||
| if punctuation_only: | ||
| processed_texts = [punctuation_model.restore_punctuation(text) for text in examples[text_column]] | ||
| elif truecase_only: | ||
| processed_texts = [true_case(text, lang_code) for text in examples[text_column]] | ||
| else: | ||
| restored_texts = [punctuation_model.restore_punctuation(text) for text in examples[text_column]] | ||
| processed_texts = [true_case(text, lang_code) for text in restored_texts] | ||
|
|
||
| result[f"original_{text_column}"] = examples[text_column] | ||
| result[text_column] = processed_texts | ||
|
|
||
| # Process description column | ||
| if description_column: | ||
| if punctuation_only: | ||
| processed_descriptions = [remove_quotes(punctuation_model.restore_punctuation(text)) for text in examples[description_column]] | ||
| elif truecase_only: | ||
| processed_descriptions = [remove_quotes(true_case(text, lang_code)) for text in examples[description_column]] | ||
| else: | ||
| restored_descriptions = [punctuation_model.restore_punctuation(text) for text in examples[description_column]] | ||
| processed_descriptions = [remove_quotes(true_case(text, lang_code)) for text in restored_descriptions] | ||
|
|
||
| result[f"original_{description_column}"] = examples[description_column] | ||
| result[description_column] = processed_descriptions | ||
|
|
||
| return result | ||
|
|
||
| if __name__ == "__main__": | ||
| set_start_method("spawn") | ||
| parser = argparse.ArgumentParser() | ||
|
|
||
| parser.add_argument("dataset_name", type=str, help="Path or name of the dataset.") | ||
| parser.add_argument("--configuration", default=None, type=str, help="Dataset configuration to use, if necessary.") | ||
| parser.add_argument("--output_dir", default=None, type=str, help="If specified, save the dataset on disk with this path.") | ||
| parser.add_argument("--repo_id", default=None, type=str, help="If specified, push the dataset to the hub.") | ||
| parser.add_argument("--text_column", default="text", type=str, help="Name of the column containing the main text to be processed.") | ||
| parser.add_argument("--description_column", default="text_description", type=str, help="Name of the column containing the description text to be processed.") | ||
| parser.add_argument("--language", default=None, type=str, help="Language of the dataset. If not specified, uses the default multilingual model.") | ||
| parser.add_argument("--batch_size", default=32, type=int, help="This parameter specifies how many samples are passed by workers for operations that are using GPUs.") | ||
| parser.add_argument("--cpu_num_workers", default=1, type=int, help="Number of CPU workers for transformations that don't use GPUs or if no GPU are available.") | ||
| parser.add_argument("--punctuation_only", action="store_true", help="If set, only perform punctuation restoration.") | ||
| parser.add_argument("--truecase_only", action="store_true", help="If set, only perform true casing.") | ||
| parser.add_argument("--process_text_only", action="store_true", help="If set, only process the text column.") | ||
| parser.add_argument("--process_description_only", action="store_true", help="If set, only process the description column.") | ||
|
|
||
| args = parser.parse_args() | ||
|
|
||
| if args.punctuation_only and args.truecase_only: | ||
| raise ValueError("Cannot set both --punctuation_only and --truecase_only. Choose one or neither.") | ||
|
|
||
| if args.process_text_only and args.process_description_only: | ||
| raise ValueError("Cannot set both --process_text_only and --process_description_only. Choose one or neither.") | ||
|
|
||
| if args.configuration: | ||
| dataset = load_dataset(args.dataset_name, args.configuration, num_proc=args.cpu_num_workers) | ||
| else: | ||
| dataset = load_dataset(args.dataset_name, num_proc=args.cpu_num_workers) | ||
|
|
||
| language_to_code = { | ||
| "catalan": "ca", | ||
| "english": "en", | ||
| "german": "de", | ||
| "french": "fr", | ||
| "spanish": "es", | ||
| "italian": "it", | ||
| "polish": "pl", | ||
| "dutch": "nl", | ||
| "portuguese": "pt", | ||
| } | ||
|
|
||
| supported_languages = set(language_to_code.keys()) | ||
| if args.language and args.language.lower() not in supported_languages: | ||
| raise ValueError(f"Language {args.language} is not supported. Please choose from: {', '.join(supported_languages)}") | ||
|
|
||
| lang_code = language_to_code[args.language.lower()] if args.language else 'en' | ||
|
|
||
| if lang_code == "ca": | ||
| punctuation_model = PunctuationModel(model="softcatala/fullstop-catalan-punctuation-prediction") | ||
| elif lang_code in {"en", "it", "fr", "de", "nl"}: | ||
| punctuation_model = PunctuationModel(model="oliverguhr/fullstop-punctuation-multilingual-base") | ||
| else: | ||
| punctuation_model = PunctuationModel(model="kredor/punctuate-all") | ||
|
|
||
| # Determine which columns to process | ||
| text_column = args.text_column if not args.process_description_only else None | ||
| description_column = args.description_column if not args.process_text_only else None | ||
|
|
||
| processed_dataset = dataset.map( | ||
| apply_processing, | ||
| batched=True, | ||
| batch_size=args.batch_size, | ||
| fn_kwargs={ | ||
| "punctuation_model": punctuation_model, | ||
| "text_column": text_column, | ||
| "description_column": description_column, | ||
| "lang_code": lang_code, | ||
| "punctuation_only": args.punctuation_only, | ||
| "truecase_only": args.truecase_only | ||
| }, | ||
| desc="Processing text" | ||
| ) | ||
|
|
||
| if args.output_dir: | ||
| print("Saving to disk...") | ||
| processed_dataset.save_to_disk(args.output_dir) | ||
|
|
||
| if args.repo_id: | ||
| print("Pushing to the hub...") | ||
| if args.configuration: | ||
| processed_dataset.push_to_hub(args.repo_id, args.configuration) | ||
| else: | ||
| processed_dataset.push_to_hub(args.repo_id) | ||
|
|
||
| print("Processing completed for the following columns:", ", ".join(filter(None, [text_column, description_column]))) | ||
| if args.punctuation_only: | ||
| print("Operation: Punctuation restoration") | ||
| elif args.truecase_only: | ||
| print("Operation: True casing") | ||
| else: | ||
| print("Operation: Punctuation restoration and true casing") | ||
| if description_column: | ||
| print("Note: Single and double quotes were removed from the description column") |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.