ner_fcn.py

from google.cloud import datastore
from google.oauth2 import service_account
import logging
import re
import os
import en_core_sci_sm, en_core_sci_lg, en_ner_bionlp13cg_md
from scispacy.umls_linking import UmlsEntityLinker
from scispacy.abbreviation import AbbreviationDetector

# DEVELOPER: change path to key
# project_id = os.getenv('PROJECT_ID')
# bucket_name = os.getenv('BUCKET_NAME')
# location = os.getenv('LOCATION')
# key_path = os.getenv('SA_KEY_PATH')
# credentials = service_account.Credentials.from_service_account_file(key_path)
#
# datastore_client = datastore.Client(credentials=credentials,
#                                     project=credentials.project_id)


def loadModel(model=en_core_sci_lg):
    """
    Load a Named Entity Recognition model and attach UMLS linking and abbreviation detection.
    Args:
        model: options: en_core_sci_sm, en_core_sci_lg, en_ner_bionlp13cg_md
    Returns:
        nlp: loaded model
    """
    # Load the model
    nlp = model.load()

    # Add the UMLS entity linker to the pipeline
    linker = UmlsEntityLinker(resolve_abbreviations=True)
    nlp.add_pipe(linker)

    # Add the abbreviation pipe to the spacy pipeline.
    abbreviation_pipe = AbbreviationDetector(nlp)
    nlp.add_pipe(abbreviation_pipe)

    logging.info("Model and add-ons successfully loaded.")
    return nlp
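
# A minimal usage sketch (not part of the original module), assuming the
# en_core_sci_lg package and the UMLS linker resources are installed locally:
#
#   nlp = loadModel(model=en_core_sci_lg)
#   doc = nlp("Patient was diagnosed with glioblastoma and started on temozolomide.")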


def extractMedEntities(vectorized_doc):
    """
    Returns UMLS entities contained in a text.
    Args:
        vectorized_doc: spaCy Doc produced by the loaded model
    Returns:
        UMLS_tuis_entity: dict - key: entity and value: TUI code
    """
    # Pattern for TUI code
    pattern = r'T(\d{3})'

    UMLS_tuis_entity = {}
    entity_dict = {}

    # The linker gives access to the UMLS knowledge base (CUI -> entity details)
    linker = UmlsEntityLinker(resolve_abbreviations=True)

    for idx in range(len(vectorized_doc.ents)):
        entity = vectorized_doc.ents[idx]
        entity_dict[entity] = ''
        for umls_ent in entity._.umls_ents:
            entity_dict[entity] = linker.umls.cui_to_entity[umls_ent[0]]

        # RegEx expression if contains TUI code
        tui = re.search(pattern, str(entity_dict[entity]))
        if tui:
            UMLS_tuis_entity[str(entity)] = tui.group()
        else:
            UMLS_tuis_entity[str(entity)] = None

    return UMLS_tuis_entity
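
# A hypothetical result for the sketch above (actual values depend on the model
# and UMLS linker versions installed):
#
#   extractMedEntities(doc)
#   # -> {'glioblastoma': 'T191', 'temozolomide': 'T109', ...}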


def addTask(client, doc_title, entities_dict):
    """
    Upload entities to Datastore.
    Args:
        client: Datastore Client object
        doc_title: str - document title, used as the entity's key name
        entities_dict: dict - entity text mapped to its TUI code
    Returns:
        task: the stored Datastore entity, fetched back by its key.
    """
    key = client.key('case', doc_title)
    task = datastore.Entity(key=key)
    task.update(entities_dict)
    client.put(task)
    # Then get by key for this entity
    logging.info("Uploaded {} to Datastore.".format(doc_title))
    return client.get(key)
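
# A minimal sketch (not in the original): store the extracted entities for one
# document, assuming a configured Datastore client and a made-up title:
#
#   addTask(datastore_client, 'case_001', extractMedEntities(doc))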


def getCases(datastore_client, filter_dict, limit=10):
    """
    Get results of a query with custom filters.
    Args:
        datastore_client: Client object
        filter_dict: dict - e.g. {parameter_A: [entity_name_A, entity_name_B],
                                  parameter_B: [entity_name_C]}
        limit: int - maximum number of results, 10 by default
    Returns:
        results: list - query results
    """
    query = datastore_client.query(kind='case')
    for key, values in filter_dict.items():
        for value in values:
            query.add_filter(key, '=', value)
    results = list(query.fetch(limit=limit))
    return results
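

# End-to-end sketch, not part of the original pipeline: wires the helpers above
# together for a quick local test. It assumes Datastore credentials are already
# configured (e.g. via GOOGLE_APPLICATION_CREDENTIALS or the commented-out
# SA_KEY_PATH block at the top) and that the scispacy models are installed;
# the sample text, document title, and filter values are illustrative only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    datastore_client = datastore.Client()

    nlp = loadModel(model=en_core_sci_lg)
    doc = nlp("Patient was diagnosed with glioblastoma and started on temozolomide.")

    entities = extractMedEntities(doc)
    addTask(datastore_client, "case_001", entities)

    # Query back stored cases whose 'glioblastoma' property holds the extracted TUI code.
    cases = getCases(datastore_client, {"glioblastoma": [entities.get("glioblastoma")]}, limit=5)
    print(cases)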