ner_fcn.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. from google.cloud import datastore
  2. from google.oauth2 import service_account
  3. import logging
  4. import re
  5. import os
  6. import en_core_sci_sm, en_core_sci_lg, en_ner_bionlp13cg_md
  7. from scispacy.umls_linking import UmlsEntityLinker
  8. from scispacy.abbreviation import AbbreviationDetector
  9. # DEVELOPER: change path to key
  10. project_id = os.getenv('PROJECT_ID')
  11. bucket_name = os.getenv('BUCKET_NAME')
  12. location = os.getenv('LOCATION')
  13. key_path = os.getenv('SA_KEY_PATH')
  14. credentials = service_account.Credentials.from_service_account_file(key_path)
  15. datastore_client = datastore.Client(credentials=credentials,
  16. project_id=credentials.project_id)
  17. def loadModel(model=en_core_sci_lg):
  18. """
  19. Loading Named Entity Recognition model.
  20. Args:
  21. model: options: en_core_sci_sm, en_core_sci_lg, en_ner_bionlp13cg_md
  22. Returns:
  23. nlp: loaded model
  24. """
  25. # Load the model
  26. nlp = model.load()
  27. # Add pipe features to pipeline
  28. linker = UmlsEntityLinker(resolve_abbreviations=True)
  29. nlp.add_pipe(linker)
  30. # Add the abbreviation pipe to the spacy pipeline.
  31. abbreviation_pipe = AbbreviationDetector(nlp)
  32. nlp.add_pipe(abbreviation_pipe)
  33. logging.info("Model and add-ons successfully loaded.")
  34. return nlp
  35. def extractMedEntities(vectorized_doc):
  36. """
  37. Returns UMLS entities contained in a text.
  38. Args:
  39. vectorized_doc:
  40. Returns:
  41. UMLS_tuis_entity: dict - key: entity and value: TUI code
  42. """
  43. # Pattern for TUI code
  44. pattern = 'T(\d{3})'
  45. UMLS_tuis_entity = {}
  46. entity_dict = {}
  47. linker = UmlsEntityLinker(resolve_abbreviations=True)
  48. for idx in range(len(vectorized_doc.ents)):
  49. entity = vectorized_doc.ents[idx]
  50. entity_dict[entity] = ''
  51. for umls_ent in entity._.umls_ents:
  52. entity_dict[entity] = linker.umls.cui_to_entity[umls_ent[0]]
  53. # RegEx expression if contains TUI code
  54. tui = re.search(pattern, str(entity_dict[entity]))
  55. if tui:
  56. UMLS_tuis_entity[str(entity)] = tui.group()
  57. else:
  58. UMLS_tuis_entity[str(entity)] = None
  59. return UMLS_tuis_entity
  60. def addTask(client, doc_title, entities_dict):
  61. """
  62. Upload entities to Datastore.
  63. Args:
  64. client:
  65. doc_title:
  66. entities_dict:
  67. Returns:
  68. Datastore key object.
  69. """
  70. key = client.key('case', doc_title)
  71. task = datastore.Entity(key=key)
  72. task.update(
  73. entities_dict
  74. )
  75. client.put(task)
  76. # Then get by key for this entity
  77. logging.info("Uploaded {} to Datastore.".format(doc_title))
  78. return client.get(key)
  79. def getCases(datastore_client, filter_dict, limit=10):
  80. """
  81. Get results of query with custom filters
  82. Args:
  83. datastore_client: Client object
  84. filter_dict: dict - e.g {parameter_A: [entity_name_A, entity_name_B],
  85. parameter_B: [entitiy_name_C]
  86. }
  87. limit: int - result limits per default 10
  88. Returns:
  89. results: list - query results
  90. """
  91. query = datastore_client.query(kind='case')
  92. for key, values in filter_dict.items():
  93. for value in values:
  94. query.add_filter(key, '=', value)
  95. results = list(query.fetch(limit=limit))
  96. return results