Using adversarial networks to achieve human-level performance for chest x-ray organ segmentation
This is Part 2 of a two-part series. See Part 1 for the challenges and clinical applications of chest X-ray (CXR) segmentation, and for why medical imaging, and CXRs specifically, critically need AI to scale.
Recap from Part 1
The task of chest X-ray (CXR) segmentation is to recognize the lung fields and the heart regions in CXRs:
Left: CXR from the Japanese Society of Radiological Technology. Right: The same CXR overlaid with human-labeled left lung, right lung, and heart contours.
Among a number of clinical applications, lung segmentation leads directly to a key clinical indicator, the cardiothoracic ratio (CTR), which is used in diagnosing cardiomegaly.
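To make the link between segmentation masks and CTR concrete, here is a minimal sketch in plain Python. It assumes CTR is approximated as the horizontal extent of the heart mask divided by the horizontal extent spanned by both lung masks; the function names and the toy masks are hypothetical, and a clinical implementation would measure the thoracic width more carefully.

```python
def column_extent(mask):
    """Width in pixels of the horizontal span covered by a binary 2D mask."""
    cols = [x for row in mask for x, v in enumerate(row) if v]
    if not cols:
        raise ValueError("mask contains no foreground pixels")
    return max(cols) - min(cols) + 1


def cardiothoracic_ratio(heart, left_lung, right_lung):
    """Heart width divided by the width spanned by both lung fields.

    Assumption: the union of the two lung masks stands in for the thorax.
    """
    lungs = [
        [a or b for a, b in zip(l_row, r_row)]
        for l_row, r_row in zip(left_lung, right_lung)
    ]
    return column_extent(heart) / column_extent(lungs)


# Toy 3 x 10 masks (1 = foreground); real masks would match the CXR size.
right_lung = [[0, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 3
heart      = [[0, 0, 0, 0, 1, 1, 1, 0, 0, 0]] * 3
left_lung  = [[0, 0, 0, 0, 0, 0, 0, 1, 1, 0]] * 3

ctr = cardiothoracic_ratio(heart, left_lung, right_lung)
print(f"CTR = {ctr:.3f}")  # heart spans 3 columns, lungs span 8: prints "CTR = 0.375"
```

Because the ratio depends only on the outermost mask columns, a single "bleeding" mask error can change it substantially, which is one reason robust segmentation matters here.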
Designing the Solution
Given the challenges of working with CXRs (see Part 1), we first design a segmentation model based on the Fully Convolutional Network (FCN). We then augment it with adversarial training in the Structure Correcting Adversarial Network (SCAN) framework, which achieves human-level performance.
Let’s dive deep into the models and the thought process behind their design.
Segmentation with Fully Convolutional Network (FCN)
The input to the segmentation model is an image of dimension H x W x C (height, width, channels), where C = 3 for RGB values, or C = 1 for grayscale images like CXRs. The model then outputs per-pixel class probabilities of dimension H x W x T, where T is the number of classes. In our case T = 4 for [left lung, right lung, heart, background], and T = 3 when heart segmentation labels are not available (as in one of the datasets).
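As a rough illustration of the output shape, the sketch below turns H x W x T raw scores into per-pixel class probabilities and then into a label map. The use of a softmax is an assumption (the text only says the output is a class distribution), and the function names are hypothetical. Plain nested lists are used so the snippet runs without dependencies; in practice an array library would do this in one call.

```python
import math

CLASSES = ["left lung", "right lung", "heart", "background"]  # T = 4


def pixelwise_softmax(logits):
    """Map H x W x T raw scores to H x W x T probabilities (each pixel sums to 1)."""
    probs = []
    for row in logits:
        out_row = []
        for scores in row:
            m = max(scores)  # subtract the max for numerical stability
            exps = [math.exp(s - m) for s in scores]
            total = sum(exps)
            out_row.append([e / total for e in exps])
        probs.append(out_row)
    return probs


def predicted_labels(probs):
    """Pick the most probable class index for every pixel (H x W)."""
    return [[max(range(len(p)), key=p.__getitem__) for p in row] for row in probs]


# A 1 x 2 "image": the first pixel favors class 0, the second favors class 3.
logits = [[[2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 3.0]]]
probs = pixelwise_softmax(logits)
labels = predicted_labels(probs)
print([[CLASSES[i] for i in row] for row in labels])
```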
We design the network to be fully convolutional, replacing fully connected layers with 1×1 convolutions. (See here for more details.) We started with a VGG-like architecture with about 16 weight layers and many feature maps (convolutional channels): 64 feature maps in the first convolution, doubling up to 512 channels in the final layers. The resulting model has such a large capacity (>100 million parameters) that it fits the training data perfectly but performs poorly on the test data. This is a clear indication that our dataset is too small to support a model this large.
Since CXR images are grayscale with standardized structures, we reduce the number of filters and find that using 8 feature maps for the first convolution, instead of 64 as in VGG, gives much improved results. However, we quickly run into the model capacity limit. To increase model capacity, we go deeper. Eventually we arrive at a “skinny” and deep network with 21 weight layers:
The segmentation network architecture. The input image (400 pixels by 400 pixels) goes through convolutions (in the residual blocks, “Resblock”) where the spatial resolution decreases but the number of feature maps (“semantic concepts”) increases. [Further tech details: the integer sequence (1, 8, 16,…) is the number of feature maps. The k × k labels below the Resblock (residual block), average pool, and deconvolution arrows indicate the receptive field sizes. The dark gray arrow denotes 5 resblocks. All convolutional layers have stride 1 × 1, while all average pooling layers have stride 2 × 2. The output is the class distribution for 4 classes (3 foreground + 1 background).]
The total number of parameters in the model is 271k, which is 500x smaller than the VGG-based segmentation models.
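To see why narrowing the network shrinks it so dramatically, note that the parameter count of a convolution layer scales with the product of its input and output channel counts. The sketch below compares one layer at VGG-like width with one at our reduced width; the 3 × 3 kernel size and the specific layer pairing are assumptions made for illustration only.

```python
def conv_params(k, c_in, c_out):
    """Weights plus biases of one k x k convolution layer."""
    return k * k * c_in * c_out + c_out


wide = conv_params(3, 64, 128)   # VGG-like width: 64 -> 128 feature maps
skinny = conv_params(3, 8, 16)   # reduced width: 8 -> 16 feature maps

print(wide, skinny, round(wide / skinny, 1))  # prints: 73856 1168 63.2
```

Cutting both channel counts by 8x reduces the weights of that layer by roughly 64x, so the skinny network can afford more layers while staying far smaller overall.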
Performance of the Segmentation Model
Intersection over Union (IoU) is computed between the ground truth mask and the predicted segmentation
Because the model is so small (very few parameters), we can train it from scratch on just the 209 CXR examples. We use the Intersection over Union (IoU) metric to evaluate the quality of the lung and heart segmentations. (See left image for a pictorial definition.) IoU ranges from 0 (no overlap between the predicted mask and the ground truth) to 1 (perfect match).
We use CXRs from the Japanese Society of Radiological Technology (JSRT) dataset and labels from another study to prepare the JSRT dataset, consisting of 247 CXRs (209 for training and validation, 38 for evaluation). This skinny and deep segmentation network (which we call FCN for fully convolutional network) performs pretty well:
Notice that human performance is not perfect; it is limited by the inherently subjective interpretation needed to draw the boundaries. The low heart IoU by human observers indicates that heart boundaries are especially difficult to infer (see challenges in Part 1). This is just one of the many places where medicine isn’t an exact science.
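For readers who prefer code to pictures, here is a minimal sketch of IoU for a single class on binary masks. The function name is hypothetical, and returning 1.0 when both masks are empty is a convention we assume here, not something defined in the text.

```python
def iou(pred, truth):
    """Intersection over Union of two binary 2D masks of the same shape."""
    inter = union = 0
    for p_row, t_row in zip(pred, truth):
        for p, t in zip(p_row, t_row):
            inter += bool(p and t)  # pixel is in both masks
            union += bool(p or t)   # pixel is in at least one mask
    # Assumed convention: two empty masks count as a perfect match.
    return inter / union if union else 1.0


pred  = [[1, 1, 0],
         [0, 1, 0]]
truth = [[1, 0, 0],
         [0, 1, 1]]

print(iou(pred, truth))  # 2 shared pixels out of 4 covered: prints 0.5
```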
Failure Modes of FCN
It’s often helpful to visualize what happens in the low-performing samples and do a failure analysis. Below we apply our model trained on the JSRT dataset to both JSRT and another dataset (which we call Montgomery):
Each column is a patient. The left two columns are patients from the JSRT evaluation set, with models trained on the JSRT development set. The right two columns are from the Montgomery dataset, using a model trained on the full JSRT dataset only (no Montgomery data), which is a much more challenging scenario. Note that only the JSRT dataset (left two columns) has heart annotations for evaluating heart area IoU.
Aside: In the image above, notice that CXRs from different datasets look quite different due to factors like different equipment, medical operators, and populations. Adapting to a new dataset domain is therefore a much more difficult task. Given that, our segmentation model already performs surprisingly well on lung segmentation in the Montgomery dataset without ever seeing an image from that population.
These failure cases reveal the difficulties that arise from the varying contrast of CXR images across samples. For example, in the image above, the apex of the rightmost patient’s ribcage is mistaken for an internal rib bone, so the mask “bleeds out” into the black background, which has a similar intensity to the lung field. Vascular structures around the mediastinum (the “white stuff” between the two lungs) and anterior rib bones (the criss-crossing lines in the lung fields) can also have intensity and texture similar to the exterior boundary, resulting in the drastic mistakes seen in the middle two columns.
Structure Correcting Adversarial Network (SCAN)
The failure cases tell us that the model needs a sense of global structure to avoid drastic failures like the earlier examples. For example, anyone with basic training knows that the heart should be more or less elliptical, while the apex of the lung fields should be smooth and the angle where the diaphragm meets the ribcage should be sharp. But just how should we teach this knowledge to the FCN segmentation model?
While it’s not easy to mathematically encode the knowledge (for example, exactly how sharp is a sharp angle?), it’s pretty easy to tell whether the predicted segmentation looks natural or not. In machine learning lingo that’s called a binary classification problem. This naturally leads to the following adversarial framework:
Overview of the proposed Structure Correcting Adversarial Network (SCAN) framework that jointly trains a segmentation network and a critic network in an adversarial setting. The segmentation network produces per-pixel class prediction. The critic takes either the ground truth label or the prediction by the segmentation network, optionally with the CXR image, and outputs the probability estimates of whether the input is the ground truth (with training target 1) or the segmentation network prediction (with training target 0).
The key addition here is that the segmentation network’s prediction is evaluated not only by the per-pixel loss (i.e., how well the predicted mask matches the ground truth pixel by pixel), but also by an “overall look and feel” evaluation given by the critic network (i.e., whether the predicted mask looks real enough to fool the critic network). Astute readers might notice that this is very similar to Generative Adversarial Networks (GAN). Indeed, this framework can be viewed as a conditional GAN, where we generate the masks from an input CXR image instead of from a random noise vector as in the original GAN.
In our work we design the critic network to largely mirror the segmentation network’s architecture. Details such as training objectives, hyperparameters of the model, and experiment setups can be found in our paper.
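As a rough sketch of how the two training signals could be combined, the snippet below adds an adversarial term to a per-pixel cross-entropy loss. This is an illustration under stated assumptions, not the exact objective from the paper: the function names are hypothetical, the weighting factor `lam` and its value are made up, and the critic is represented only by the scalar probability it outputs. The targets (1 for ground truth, 0 for a prediction) follow the framework description above.

```python
import math

EPS = 1e-12  # avoids log(0)


def pixel_cross_entropy(probs, labels):
    """Mean negative log-probability of the true class over all pixels.

    probs: H x W x T nested lists, labels: H x W integer class indices.
    """
    losses = [
        -math.log(max(p[y], EPS))
        for p_row, y_row in zip(probs, labels)
        for p, y in zip(p_row, y_row)
    ]
    return sum(losses) / len(losses)


def binary_cross_entropy(score, target):
    """Loss of a critic output in (0, 1) against a 0/1 training target."""
    score = min(max(score, EPS), 1.0 - EPS)
    return -(target * math.log(score) + (1 - target) * math.log(1.0 - score))


def critic_loss(score_on_truth, score_on_pred):
    """The critic should output 1 on ground truth masks and 0 on predictions."""
    return binary_cross_entropy(score_on_truth, 1) + binary_cross_entropy(score_on_pred, 0)


def segmenter_loss(probs, labels, score_on_pred, lam=0.01):
    """Per-pixel loss plus a term rewarding predictions the critic rates as real.

    lam is a hypothetical weighting factor chosen only for illustration.
    """
    return pixel_cross_entropy(probs, labels) + lam * binary_cross_entropy(score_on_pred, 1)


# Toy 1 x 2 image with T = 2 classes.
probs = [[[0.5, 0.5], [0.25, 0.75]]]
labels = [[0, 1]]
fooled = segmenter_loss(probs, labels, score_on_pred=0.9)
caught = segmenter_loss(probs, labels, score_on_pred=0.1)
print(fooled < caught)  # a prediction that fools the critic incurs a lower loss
```

In training, the two losses would be minimized in alternation: the critic learns to tell predictions from ground truth, and the segmentation network learns to produce masks the critic cannot tell apart.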
Performance of SCAN
Before we dive into the numbers, we should clarify that the critic network in SCAN is only involved during the training stage. During testing, we only use the segmentation network, which has an identical architecture to FCN. In other words, our hope is that by adding the critic network we can train the same segmentation network better, using the critic’s guidance to push it towards more “natural” predictions. With that in mind, we repeat the evaluation on the JSRT dataset:
Evaluation on the JSRT dataset. FCN is the segmentation model only. The registration-based method is the prior state of the art for lung field segmentation (no heart), from Candemir et al., TMI (2014).
Notice that without any change in the FCN architecture, SCAN improves on FCN by an absolute 1.8% to reach human-level performance, at around 94.7% lung IoU! Let’s revisit the 4 difficult patients in our failure cases:
Each column is a patient. The left two columns are patients from the JSRT evaluation set, with models trained on the JSRT development set. The right two columns are from the Montgomery dataset, using a model trained on the full JSRT dataset only (but no Montgomery data), which is a much more challenging scenario. Note that only the two patients from the JSRT dataset (left two columns) have heart annotations for evaluating heart area IoU. These examples aren’t cherry-picked results, but are in fact the more difficult cases. For example, notice that the 91.4% lung IoU by SCAN in the leftmost column is already well below the average 94.7% IoU in our evaluation (see the evaluation table above).
As you can see, all 4 cases are “fixed” pretty satisfactorily. Furthermore, notice that SCAN produces a more realistic sharp angle at the outer lower corner of each lung field (the costophrenic angle) than FCN does. The corners generally don’t affect the per-pixel performance, but can be important in downstream diagnostic tasks (e.g., detecting the blunting of the costophrenic angle).
In clinical settings, good average performance is not enough; it’s also important to avoid outrageous errors in prediction, as they can erode doctors’ trust in AI. By using the adversarial learning framework, SCAN improves both the per-pixel metrics and the “overall look and feel” of the prediction, and both are important in clinical settings.
Comparison with Prior State of the Art for CXR Segmentation
The evaluation table above shows that our method outperforms the prior state of the art for CXR lung field segmentation (the “registration-based” method) by a large margin. Since our work is the first deep learning solution for CXR segmentation, it’s helpful to have some perspective on how complex non-deep learning solutions can be:
The CXR lung segmentation pipeline used in Candemir et al., TMI (2014)
The approach in Candemir et al., TMI (2014) involves a series of stages: SIFT feature extraction, shape transformation, finding patients with similar lung shape profiles as candidate CXR segmentations, graph cut, and so on, to produce the final segmentation. Each stage requires various tuning parameters. And since the prediction is based on deforming the masks of patients with similar lung profiles, performance suffers when the new patient’s lungs are sufficiently different from the existing training data, as we will see later.
The complex pipeline in Candemir et al., TMI (2014) stands in stark contrast with the simplicity of neural networks, where the network learns both the features and the shapes on its own. Gone are the days of handcrafted features like SIFT and delicate shape manipulations in a series of stages.
It’s helpful to have some qualitative comparison to understand how SCAN outperforms Candemir et al., TMI (2014):
The left two columns are from the JSRT evaluation set (using model trained on JSRT development set), and the right two columns are from the Montgomery set (using model trained on the full JSRT dataset).
For the left two columns, SCAN produces more realistic contours around the sharp costophrenic angles. This may be a challenge for registration-based models, where detecting and matching the costophrenic point is difficult. For the right two columns, the method of Candemir et al., TMI (2014) struggles due to the mismatch between the test patients’ lung profiles (from the Montgomery dataset) and the existing lung profiles in the JSRT dataset, leading to unnatural mask shapes.
There’s been much hype around AI’s diagnostic accuracy on CXRs. However, AI-based diagnosis on CXRs can be met with suspicion from radiologists. While there are exciting results, it’s often easier to make inroads into hospitals with smaller improvements like cardiothoracic ratio (CTR) calculation, which can be derived from lung segmentations (see Part 1). We were able to go into trials with our CTR engine quickly. Automated CTR calculation is easy to interpret and generally very accurate. We’ve found that sometimes it’s more important to gain the trust of doctors and domain experts by supporting their existing workflow well with robust AI than to change their workflow with less mature AI solutions. I hope that this case study can serve as a helpful example for the development of other healthcare AI solutions.