Implement force download parameter in model insert_defaults methods

Patrick Jentsch 2023-11-28 12:10:55 +01:00
parent bdcc80a66f
commit 9c22370eea


@@ -953,7 +953,7 @@ class TesseractOCRPipelineModel(FileMixin, HashidMixin, db.Model):
         return self.user.hashid
 
     @staticmethod
-    def insert_defaults():
+    def insert_defaults(force_download=False):
         nopaque_user = User.query.filter_by(username='nopaque').first()
         defaults_file = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
@@ -966,6 +966,7 @@ class TesseractOCRPipelineModel(FileMixin, HashidMixin, db.Model):
             if model is not None:
                 model.compatible_service_versions = m['compatible_service_versions']
                 model.description = m['description']
+                model.filename = f'{model.id}.traineddata'
                 model.publisher = m['publisher']
                 model.publisher_url = m['publisher_url']
                 model.publishing_url = m['publishing_url']
@@ -973,38 +974,39 @@ class TesseractOCRPipelineModel(FileMixin, HashidMixin, db.Model):
                 model.is_public = True
                 model.title = m['title']
                 model.version = m['version']
-                continue
-            model = TesseractOCRPipelineModel(
-                compatible_service_versions=m['compatible_service_versions'],
-                description=m['description'],
-                publisher=m['publisher'],
-                publisher_url=m['publisher_url'],
-                publishing_url=m['publishing_url'],
-                publishing_year=m['publishing_year'],
-                is_public=True,
-                title=m['title'],
-                user=nopaque_user,
-                version=m['version']
-            )
-            db.session.add(model)
-            db.session.flush(objects=[model])
-            db.session.refresh(model)
-            model.filename = f'{model.id}.traineddata'
-            r = requests.get(m['url'], stream=True)
-            pbar = tqdm(
-                desc=f'{model.title} ({model.filename})',
-                unit="B",
-                unit_scale=True,
-                unit_divisor=1024,
-                total=int(r.headers['Content-Length'])
-            )
-            pbar.clear()
-            with open(model.path, 'wb') as f:
-                for chunk in r.iter_content(chunk_size=1024):
-                    if chunk: # filter out keep-alive new chunks
-                        pbar.update(len(chunk))
-                        f.write(chunk)
-            pbar.close()
+            else:
+                model = TesseractOCRPipelineModel(
+                    compatible_service_versions=m['compatible_service_versions'],
+                    description=m['description'],
+                    publisher=m['publisher'],
+                    publisher_url=m['publisher_url'],
+                    publishing_url=m['publishing_url'],
+                    publishing_year=m['publishing_year'],
+                    is_public=True,
+                    title=m['title'],
+                    user=nopaque_user,
+                    version=m['version']
+                )
+                db.session.add(model)
+                db.session.flush(objects=[model])
+                db.session.refresh(model)
+                model.filename = f'{model.id}.traineddata'
+            if not os.path.exists(model.path) or force_download:
+                r = requests.get(m['url'], stream=True)
+                pbar = tqdm(
+                    desc=f'{model.title} ({model.filename})',
+                    unit="B",
+                    unit_scale=True,
+                    unit_divisor=1024,
+                    total=int(r.headers['Content-Length'])
+                )
+                pbar.clear()
+                with open(model.path, 'wb') as f:
+                    for chunk in r.iter_content(chunk_size=1024):
+                        if chunk: # filter out keep-alive new chunks
+                            pbar.update(len(chunk))
+                            f.write(chunk)
+                pbar.close()
         db.session.commit()
 
     def delete(self):
@@ -1080,7 +1082,7 @@ class SpaCyNLPPipelineModel(FileMixin, HashidMixin, db.Model):
         return self.user.hashid
 
     @staticmethod
-    def insert_defaults():
+    def insert_defaults(force_download=False):
         nopaque_user = User.query.filter_by(username='nopaque').first()
         defaults_file = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
@@ -1093,6 +1095,7 @@ class SpaCyNLPPipelineModel(FileMixin, HashidMixin, db.Model):
             if model is not None:
                 model.compatible_service_versions = m['compatible_service_versions']
                 model.description = m['description']
+                model.filename = m['url'].split('/')[-1]
                 model.publisher = m['publisher']
                 model.publisher_url = m['publisher_url']
                 model.publishing_url = m['publishing_url']
@@ -1101,39 +1104,40 @@ class SpaCyNLPPipelineModel(FileMixin, HashidMixin, db.Model):
                 model.title = m['title']
                 model.version = m['version']
                 model.pipeline_name = m['pipeline_name']
-                continue
-            model = SpaCyNLPPipelineModel(
-                compatible_service_versions=m['compatible_service_versions'],
-                description=m['description'],
-                publisher=m['publisher'],
-                publisher_url=m['publisher_url'],
-                publishing_url=m['publishing_url'],
-                publishing_year=m['publishing_year'],
-                is_public=True,
-                title=m['title'],
-                user=nopaque_user,
-                version=m['version'],
-                pipeline_name=m['pipeline_name']
-            )
-            db.session.add(model)
-            db.session.flush(objects=[model])
-            db.session.refresh(model)
-            model.filename = m['url'].split('/')[-1]
-            r = requests.get(m['url'], stream=True)
-            pbar = tqdm(
-                desc=f'{model.title} ({model.filename})',
-                unit="B",
-                unit_scale=True,
-                unit_divisor=1024,
-                total=int(r.headers['Content-Length'])
-            )
-            pbar.clear()
-            with open(model.path, 'wb') as f:
-                for chunk in r.iter_content(chunk_size=1024):
-                    if chunk: # filter out keep-alive new chunks
-                        pbar.update(len(chunk))
-                        f.write(chunk)
-            pbar.close()
+            else:
+                model = SpaCyNLPPipelineModel(
+                    compatible_service_versions=m['compatible_service_versions'],
+                    description=m['description'],
+                    filename=m['url'].split('/')[-1],
+                    publisher=m['publisher'],
+                    publisher_url=m['publisher_url'],
+                    publishing_url=m['publishing_url'],
+                    publishing_year=m['publishing_year'],
+                    is_public=True,
+                    title=m['title'],
+                    user=nopaque_user,
+                    version=m['version'],
+                    pipeline_name=m['pipeline_name']
+                )
+                db.session.add(model)
+                db.session.flush(objects=[model])
+                db.session.refresh(model)
+            if not os.path.exists(model.path) or force_download:
+                r = requests.get(m['url'], stream=True)
+                pbar = tqdm(
+                    desc=f'{model.title} ({model.filename})',
+                    unit="B",
+                    unit_scale=True,
+                    unit_divisor=1024,
+                    total=int(r.headers['Content-Length'])
+                )
+                pbar.clear()
+                with open(model.path, 'wb') as f:
+                    for chunk in r.iter_content(chunk_size=1024):
+                        if chunk: # filter out keep-alive new chunks
+                            pbar.update(len(chunk))
+                            f.write(chunk)
+                pbar.close()
         db.session.commit()
 
     def delete(self):
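
Usage sketch (not part of the commit): with this change, callers can force a re-download of the default model files instead of only fetching files that are missing on disk. The import paths and the create_app application factory below are assumptions for illustration; only the insert_defaults(force_download=...) calls come from the diff itself.

    from app import create_app  # assumed application factory, not shown in this diff
    from app.models import TesseractOCRPipelineModel, SpaCyNLPPipelineModel  # assumed module path

    app = create_app()
    with app.app_context():
        # Default behaviour: download only models whose files are missing on disk
        TesseractOCRPipelineModel.insert_defaults()

        # Force a fresh download even if the model files already exist
        TesseractOCRPipelineModel.insert_defaults(force_download=True)
        SpaCyNLPPipelineModel.insert_defaults(force_download=True)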