Med-PaLM M - Google develops doctor AI 😷, does it work?

Milos Vukadinovic, paper reviews

Review of the paper: Towards Generalist Biomedical AI

What? The authors developed Med-PaLM M, the first generalist biomedical AI system, and introduced MultiMedBench, a benchmark for evaluating generalist biomedical models.

Why? Medical diagnostics is naturally multimodal, so a multimodal generalist AI model could improve diagnostic quality.

How? By instruction-tuning PaLM-E, a model that can process language, vision, and sensor data.

And?

Introduction

Medicine is naturally multimodal, and diagnostics is hard. We already have many specialized medical ML models that can generate radiology reports, predict protein folding, discover drugs, and solve many other tasks. Recently, foundation models gained traction, so why not attempt to build a foundation medical AI that solves all of these tasks as well as the specialized models do? That's exactly what this paper did: "Let's build a generalist/foundation model for biomedicine."

ℹ️

A foundation model is any model trained on broad data that can be adapted to a wide range of downstream tasks. A generalist model is a model with one set of weights that can excel at a wide variety of tasks without the need for fine-tuning.

ℹ️

AI alignment research aims to steer AI systems towards humans' intended goals, preferences or ethical principles.

Goal: make a generalist biomedical AI model that works like a primary care physician. The generalist model should be able to give you initial advice and then point you to a specialized model for your problem.

MultiMedBench

There is no established benchmark for generalist biomedical AI agents, which is why MultiMedBench is introduced. It consists of 14 tasks spanning medical question answering, visual question answering, radiology report generation and summarization, medical image classification, and genomic variant calling.

It measures the capability of a general-purpose biomedical AI to perform a variety of clinically relevant tasks.

Med-PaLM M

PaLM (Pathways Language Model) is a large language model released by Google, and it is the model their Bard chatbot is based on. Additionally, Google released PaLM-E, a model built on top of PaLM and a Vision Transformer that can process not only text, but images and sensor data too. It does this by using the Vision Transformer (ViT) to produce image encodings and concatenating them with the text embeddings before feeding the sequence to the PaLM architecture. Finally, Med-PaLM M is a version of PaLM-E fine-tuned on biomedical data from the MultiMedBench dataset.
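To make that concrete, here is a minimal sketch of the idea, assuming a generic ViT encoder and a learned linear projection into the language model's embedding space (all class names and dimensions below are illustrative, not the actual PaLM-E code):

```python
import torch
import torch.nn as nn

class MultimodalPrefix(nn.Module):
    """Illustrative PaLM-E-style input construction: image patches are
    encoded by a ViT, projected into the LM's token-embedding space,
    and concatenated with the text token embeddings."""

    def __init__(self, vit_dim=768, lm_dim=4096, vocab_size=32000):
        super().__init__()
        self.vit = nn.Identity()                        # stand-in for a real ViT encoder
        self.project = nn.Linear(vit_dim, lm_dim)       # maps ViT tokens into LM space
        self.embed = nn.Embedding(vocab_size, lm_dim)   # LM token embeddings

    def forward(self, image_patches, text_token_ids):
        img_tokens = self.project(self.vit(image_patches))  # (B, P, lm_dim)
        txt_tokens = self.embed(text_token_ids)             # (B, T, lm_dim)
        # The LM then attends over the concatenated sequence.
        return torch.cat([img_tokens, txt_tokens], dim=1)   # (B, P+T, lm_dim)

prefix = MultimodalPrefix()
seq = prefix(torch.randn(1, 196, 768), torch.randint(0, 32000, (1, 32)))
print(seq.shape)  # torch.Size([1, 228, 4096])
```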

The authors used a fine-tuning method called instruction tuning. Namely, they created a specific format for each type of task in the MultiMedBench dataset. All Med-PaLM M weights were unfrozen and updated using gradient descent. The dataset that the weights are fine-tuned on is a set of prompts, where each prompt consists of the following sections: Instruction, Relevant context information, Example, and Question.

Instruction: The authors created 14 different instruction texts, one for each task. This part basically explains the rules of the game to the LM.

Relevant context information: Here we give all the relevant data that the LM can analyze to answer the question.

Example: It's useful to give the LM an example of the expected interaction. Similarly to when a person is introduced to a new game, an example is very beneficial for learning.

Question: We ask a question about the context data provided.

The labels are the answers to the questions, which we can take from the underlying datasets. Also, at test time, the input data is restructured according to the same prompt template.
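As a rough illustration, here is how such a prompt could be assembled; the exact wording and formatting in the paper differ, and `build_prompt` is my own mock-up:

```python
def build_prompt(instruction, context, example_qa=None, question=""):
    """Assemble a Med-PaLM M-style instruction-tuning prompt from its
    four sections: Instruction, Context, (optional) Example, Question."""
    parts = [instruction, context]
    if example_qa is not None:
        q, a = example_qa
        parts.append(f"Q: {q}\nA: {a}")  # one-shot example, text only
    parts.append(f"Q: {question}\nA:")   # the model completes the answer
    return "\n\n".join(parts)

# Hypothetical radiology VQA prompt (not the paper's exact wording):
prompt = build_prompt(
    instruction="You are a helpful radiology assistant.",
    # In the real model, image embeddings are injected here;
    # "<img>" is just a text stand-in for this sketch.
    context="<img> Chest X-ray, PA view.",
    example_qa=("Is there cardiomegaly?", "No."),
    question="Is there a pleural effusion?",
)
print(prompt)
```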

[Figure: an example instruction prompt for the radiology visual question answering task.]

For example, in the image above, the instruction "you are a helpful radiology assistant" is the same for every prompt of the "radiology: visual question answering" task. The context is auto-populated with information from the database, and an example question-answer pair is given. It's worth noting that the example pair does not contain image embeddings. At the end, the real question is asked, and the model is supposed to complete the answer.

In my opinion, the main technical concept in this paper is using prompt engineering in an interesting way: not only during inference but also to fine-tune on prompts.
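Mechanically, fine-tuning on prompts could look something like this: a standard next-token cross-entropy, with the loss masked to the answer tokens. The paper doesn't spell out this exact masking, so treat this as a plausible sketch rather than the authors' implementation:

```python
import torch
import torch.nn.functional as F

def instruction_tuning_loss(lm, input_ids, answer_mask):
    """Next-token cross-entropy where `answer_mask` is 1 for answer tokens,
    so the prompt itself contributes no loss. `lm` is any causal LM whose
    output has `.logits` of shape (B, T, vocab); illustrative only."""
    logits = lm(input_ids).logits
    shift_logits = logits[:, :-1].transpose(1, 2)  # (B, vocab, T-1): predict token t+1
    shift_labels = input_ids[:, 1:]                # (B, T-1)
    per_token = F.cross_entropy(shift_logits, shift_labels, reduction="none")
    mask = answer_mask[:, 1:].float()
    return (per_token * mask).sum() / mask.sum()

# All weights unfrozen, as in the paper:
# for p in lm.parameters():
#     p.requires_grad = True
```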

ℹ️

Tangent Alert

When doing prompt engineering in general, there are two methods that will superboost your LLM outputs: chain of thought and few-shot in-context learning (i.e., giving a couple of examples). Here is my interaction with ChatGPT that demonstrates the benefits of these two methods.

Now, after including chain of thought and examples:
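A mock-up of what such a prompt looks like, with two worked examples as the few shots and a "think step by step" cue for the chain of thought (illustrative wording, not my original ChatGPT exchange):

```python
# Hypothetical few-shot + chain-of-thought prompt (illustrative wording):
prompt = """Answer the question. Think step by step before answering.

Q: A patient takes 2 pills every 6 hours. How many pills in 2 days?
A: 2 days = 48 hours. 48 / 6 = 8 doses. 8 doses * 2 pills = 16 pills.
The answer is 16.

Q: A ward has 3 nurses per shift and 4 shifts a day. How many
nurse-shifts in a week?
A: 3 * 4 = 12 nurse-shifts per day. 12 * 7 = 84 per week.
The answer is 84.

Q: An IV bag holds 1000 ml and drips at 125 ml/hour. How long does it last?
A:"""
# The two worked Q/A pairs are the "few shots";
# "think step by step" elicits the chain of thought.
```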

Results

I noticed that this study has an excellent evaluation design, which the authors say was created to evaluate generalist capabilities, emergent capabilities, and report generation quality. I describe each evaluation method before discussing the results.

First, they evaluate the model's performance on the MultiMedBench dataset. What stands out is that Med-PaLM M achieved SOTA on most visual question answering, report generation, and image classification tasks.

Then, they claim that the model can achieve zero-shot generalization by demonstrating good performance on tuberculosis classification. This is zero-shot because there was no explicit task where Med-PaLM M was asked to classify tuberculosis during training. So I can imagine that during inference they created a prompt like this: "Is tuberculosis present in the following chest X-ray? Answer yes or no only."
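Such a zero-shot check could be run by treating classification as constrained generation; the prompt text and yes/no parsing below are my guesses, not the paper's protocol:

```python
def zero_shot_classify(generate, image_tokens):
    """Zero-shot tuberculosis classification as yes/no generation.
    `generate` is any text-completion function (e.g., a wrapped
    multimodal LM); everything here is illustrative."""
    prompt = (f"{image_tokens} Is tuberculosis present in the following "
              "chest X-ray? Answer yes or no only.\nA:")
    answer = generate(prompt).strip().lower()
    return answer.startswith("yes")  # True = tuberculosis predicted
```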

The third method of evaluation was having radiologists assess the quality of synthetic radiology reports. In most cases radiologists were able to differentiate between a real and a fake report, but the synthetic reports showed some signs of quality. The most important conclusion of this evaluation is that the 84-billion-parameter model performs best.

Conclusion and opinion

This paper presented an instruction tuning approach to fine-tuning pretrained foundation models. As a result, the authors made an important step in research towards developing a generalist biomedical AI.

Limitation