diff --git a/app/models.py b/app/models.py index 91477f92..ba90ca08 100644 --- a/app/models.py +++ b/app/models.py @@ -953,7 +953,7 @@ class TesseractOCRPipelineModel(FileMixin, HashidMixin, db.Model): return self.user.hashid @staticmethod - def insert_defaults(): + def insert_defaults(force_download=False): nopaque_user = User.query.filter_by(username='nopaque').first() defaults_file = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -966,6 +966,7 @@ class TesseractOCRPipelineModel(FileMixin, HashidMixin, db.Model): if model is not None: model.compatible_service_versions = m['compatible_service_versions'] model.description = m['description'] + model.filename = f'{model.id}.traineddata' model.publisher = m['publisher'] model.publisher_url = m['publisher_url'] model.publishing_url = m['publishing_url'] @@ -973,38 +974,39 @@ class TesseractOCRPipelineModel(FileMixin, HashidMixin, db.Model): model.is_public = True model.title = m['title'] model.version = m['version'] - continue - model = TesseractOCRPipelineModel( - compatible_service_versions=m['compatible_service_versions'], - description=m['description'], - publisher=m['publisher'], - publisher_url=m['publisher_url'], - publishing_url=m['publishing_url'], - publishing_year=m['publishing_year'], - is_public=True, - title=m['title'], - user=nopaque_user, - version=m['version'] - ) - db.session.add(model) - db.session.flush(objects=[model]) - db.session.refresh(model) - model.filename = f'{model.id}.traineddata' - r = requests.get(m['url'], stream=True) - pbar = tqdm( - desc=f'{model.title} ({model.filename})', - unit="B", - unit_scale=True, - unit_divisor=1024, - total=int(r.headers['Content-Length']) - ) - pbar.clear() - with open(model.path, 'wb') as f: - for chunk in r.iter_content(chunk_size=1024): - if chunk: # filter out keep-alive new chunks - pbar.update(len(chunk)) - f.write(chunk) - pbar.close() + else: + model = TesseractOCRPipelineModel( + compatible_service_versions=m['compatible_service_versions'], + description=m['description'], + publisher=m['publisher'], + publisher_url=m['publisher_url'], + publishing_url=m['publishing_url'], + publishing_year=m['publishing_year'], + is_public=True, + title=m['title'], + user=nopaque_user, + version=m['version'] + ) + db.session.add(model) + db.session.flush(objects=[model]) + db.session.refresh(model) + model.filename = f'{model.id}.traineddata' + if not os.path.exists(model.path) or force_download: + r = requests.get(m['url'], stream=True) + pbar = tqdm( + desc=f'{model.title} ({model.filename})', + unit="B", + unit_scale=True, + unit_divisor=1024, + total=int(r.headers['Content-Length']) + ) + pbar.clear() + with open(model.path, 'wb') as f: + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + pbar.update(len(chunk)) + f.write(chunk) + pbar.close() db.session.commit() def delete(self): @@ -1080,7 +1082,7 @@ class SpaCyNLPPipelineModel(FileMixin, HashidMixin, db.Model): return self.user.hashid @staticmethod - def insert_defaults(): + def insert_defaults(force_download=False): nopaque_user = User.query.filter_by(username='nopaque').first() defaults_file = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -1093,6 +1095,7 @@ class SpaCyNLPPipelineModel(FileMixin, HashidMixin, db.Model): if model is not None: model.compatible_service_versions = m['compatible_service_versions'] model.description = m['description'] + model.filename = m['url'].split('/')[-1] model.publisher = m['publisher'] model.publisher_url = m['publisher_url'] model.publishing_url = m['publishing_url'] @@ -1101,39 +1104,40 @@ class SpaCyNLPPipelineModel(FileMixin, HashidMixin, db.Model): model.title = m['title'] model.version = m['version'] model.pipeline_name = m['pipeline_name'] - continue - model = SpaCyNLPPipelineModel( - compatible_service_versions=m['compatible_service_versions'], - description=m['description'], - publisher=m['publisher'], - publisher_url=m['publisher_url'], - publishing_url=m['publishing_url'], - publishing_year=m['publishing_year'], - is_public=True, - title=m['title'], - user=nopaque_user, - version=m['version'], - pipeline_name=m['pipeline_name'] - ) - db.session.add(model) - db.session.flush(objects=[model]) - db.session.refresh(model) - model.filename = m['url'].split('/')[-1] - r = requests.get(m['url'], stream=True) - pbar = tqdm( - desc=f'{model.title} ({model.filename})', - unit="B", - unit_scale=True, - unit_divisor=1024, - total=int(r.headers['Content-Length']) - ) - pbar.clear() - with open(model.path, 'wb') as f: - for chunk in r.iter_content(chunk_size=1024): - if chunk: # filter out keep-alive new chunks - pbar.update(len(chunk)) - f.write(chunk) - pbar.close() + else: + model = SpaCyNLPPipelineModel( + compatible_service_versions=m['compatible_service_versions'], + description=m['description'], + filename=m['url'].split('/')[-1], + publisher=m['publisher'], + publisher_url=m['publisher_url'], + publishing_url=m['publishing_url'], + publishing_year=m['publishing_year'], + is_public=True, + title=m['title'], + user=nopaque_user, + version=m['version'], + pipeline_name=m['pipeline_name'] + ) + db.session.add(model) + db.session.flush(objects=[model]) + db.session.refresh(model) + if not os.path.exists(model.path) or force_download: + r = requests.get(m['url'], stream=True) + pbar = tqdm( + desc=f'{model.title} ({model.filename})', + unit="B", + unit_scale=True, + unit_divisor=1024, + total=int(r.headers['Content-Length']) + ) + pbar.clear() + with open(model.path, 'wb') as f: + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + pbar.update(len(chunk)) + f.write(chunk) + pbar.close() db.session.commit() def delete(self):