Visualizations: Correlation Charts

Introduction

Choosing the correct visualization is an important aspect of data presentation, yet many users struggle to identify the most effective visualization for their data. Each type of visualization serves a different purpose, and selecting the appropriate one requires an understanding of the data, the audience, and the message you wish to convey.

This article aims to make data visualization easier to understand. It highlights the different types of graphs and their typical use cases, and for each visualization it provides the dataset used along with the Python and R code that created the graph. You can see the full article here.

Acknowledgements

Each dataset used in this document (unless otherwise stated) can be found on vincentarelbundock.github.io, a large repository of datasets that can be used in R. I would like to thank the people responsible for making this information openly accessible. A link to the Google Sheet for each dataset is provided throughout the document.
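If a dataset's Google Sheet is shared publicly ("Anyone with the link can view"), Python users can also load it without the service-account setup used in the heatmap example, through the sheet's CSV export link. The sketch below assumes that sharing setting; `sheet_csv_url` is a hypothetical helper written for this guide, and its `gid` parameter selects a tab (0 is the default ID of a new sheet's first tab, while tabs added later get other IDs).

```python
import re


def sheet_csv_url(share_url: str, gid: int = 0) -> str:
    """Build the CSV export URL for a publicly shared Google Sheet.

    Hypothetical helper: assumes the sheet is shared as "Anyone with the link".
    `gid` is the ID of the tab to export.
    """
    # The spreadsheet ID is the path segment after /spreadsheets/d/
    match = re.search(r"/spreadsheets/d/([A-Za-z0-9_-]+)", share_url)
    if match is None:
        raise ValueError(f"Not a Google Sheets URL: {share_url}")
    sheet_id = match.group(1)
    return f"https://docs.google.com/spreadsheets/d/{sheet_id}/export?format=csv&gid={gid}"


# Example: the heatmap dataset's share link
url = sheet_csv_url(
    "https://docs.google.com/spreadsheets/d/"
    "1_TMv0gLaQ70oqOYw3l0Y_M9BzN5C5Utunhdq1paC1bs/edit?usp=sharing"
)
print(url)

# With network access and pandas installed:
# import pandas as pd
# df = pd.read_csv(url)
```

The returned URL can be passed straight to `pd.read_csv()`, but only while the sheet stays publicly shared.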

How the Guide is Formatted

The guide is organized by general group (e.g., comparison charts, correlation charts), each followed by the visualizations that fall under it. For example, bar/column charts are a type of comparison chart. Each visualization begins with a short introduction, followed by the figure. The R and Python code used to generate the graph is listed directly underneath the figure. For all visualizations, make sure you upload the data file when you start the chat, as some of the code does not include that initial step.

Correlation

Correlation graphs are used to visualize relationships between variables, showing how one variable changes in relation to another. They often show the strength and direction of these relationships, which is important in fields like statistics, economics, and data science.
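The strength and direction of a linear relationship are most often summarized with Pearson's correlation coefficient, r, which ranges from -1 to 1: its sign gives the direction and its magnitude the strength. It is the default method of R's `cor()` and pandas' `DataFrame.corr()`, both used in the examples below. A minimal plain-Python sketch of the formula, with made-up numbers for illustration:

```python
import math


def pearson_r(xs, ys):
    """Pearson's r: covariance of x and y divided by the product of their spreads."""
    n = len(xs)
    if n != len(ys) or n < 2:
        raise ValueError("need two sequences of equal length (at least 2 values)")
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Sum of products of deviations (the covariance numerator)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    # Square roots of the sums of squared deviations; the 1/(n - 1)
    # normalizations cancel in the ratio, so they are omitted
    sd_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    sd_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    if sd_x == 0 or sd_y == 0:
        raise ValueError("r is undefined when a variable is constant")
    return cov / (sd_x * sd_y)


# Made-up data for illustration only
hours = [1, 2, 3, 4, 5]
score_up = [52, 55, 61, 64, 70]    # rises with hours: positive r
score_down = [70, 64, 61, 55, 52]  # falls with hours: negative r
print(round(pearson_r(hours, score_up), 3))
print(round(pearson_r(hours, score_down), 3))
```

Pearson's r only measures linear association; a strong curved relationship can still give an r near 0, which is one reason to look at the plot and not just the number.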

1. Heatmap/Correlation Matrices

Heatmaps and correlation matrices are visualizations that are simple for readers to understand. They use colours to represent values in a two-dimensional grid; in a correlation matrix, each cell's colour shows the correlation coefficient between one pair of variables.
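A correlation matrix is symmetric (the correlation of A with B equals that of B with A) and its diagonal is always 1, so only the cells above the diagonal hold unique pairs. The minimal NumPy sketch below uses randomly generated columns named `a` to `d` as a stand-in for the gene expression data; it builds the matrix and ranks the pairs in the upper triangle by the strength of their correlation:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in data: 50 observations of 4 variables,
# where "b" is built from "a" so the two are strongly correlated
a = rng.normal(size=50)
data = {
    "a": a,
    "b": 0.9 * a + rng.normal(scale=0.3, size=50),
    "c": rng.normal(size=50),
    "d": rng.normal(size=50),
}
names = list(data)

# np.corrcoef treats each row as one variable, so stack the columns as rows
corr = np.corrcoef(np.vstack([data[name] for name in names]))

# Indices of the upper triangle, excluding the diagonal (k=1)
rows, cols = np.triu_indices(len(names), k=1)

# Rank the unique pairs by absolute correlation, strongest first
pairs = sorted(
    ((names[i], names[j], corr[i, j]) for i, j in zip(rows, cols)),
    key=lambda pair: abs(pair[2]),
    reverse=True,
)
for x, y, r in pairs[:3]:
    print(f"{x}-{y}: r = {r:+.2f}")
```

With 20 genes, as in the examples below, the same upper triangle holds 190 unique pairs, which is why a colour-coded grid is easier to scan than a list.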

For this visualization, we will use a dataset called ‘cerebellum_gene_expression2’, accessible here.

R & Python Example

R Code

#R CODE
#Load necessary libraries
library(googlesheets4)
library(dplyr)
library(corrplot)
library(viridis)

#Read the data
gs4_deauth()
df <- read_sheet('https://docs.google.com/spreadsheets/d/1_TMv0gLaQ70oqOYw3l0Y_M9BzN5C5Utunhdq1paC1bs/edit?usp=sharing')

#Select gene columns (excluding 'rownames' and 'cerebellum')
gene_cols <- colnames(df)[!colnames(df) %in% c('rownames', 'cerebellum')]

#Randomly select 20 genes
set.seed(123)  # for reproducibility
selected_genes <- sample(gene_cols, 20)

#Create a subset of the data with only the selected genes
df_subset <- df[, c('rownames', selected_genes)]

#Calculate correlation matrix
cor_matrix <- cor(df_subset[, -1])

#Create a correlation plot with viridis color palette
png("gene_correlation_plot_colorblind_friendly.png", width = 800, height = 800)
corrplot(cor_matrix, method = "color", type = "full", order = "hclust", 
         col = viridis(200), tl.col = "black", tl.srt = 45, addCoef.col = "black", 
         number.cex = 0.7, tl.cex = 0.7)
dev.off()

#Read the image file and encode it as base64
library(base64enc)
img <- base64encode("gene_correlation_plot_colorblind_friendly.png")
cat(paste0("data:image/png;base64,", img))

#Print the selected genes
print("Selected genes:")
print(selected_genes)

#Display the first few rows and columns of the correlation matrix
print("First few rows and columns of the correlation matrix:")
print(cor_matrix[1:5, 1:5])

Python Code

#PYTHON CODE
#Import necessary libraries
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import gspread

#Authorize a client with a Google service-account key file
#(gspread.service_account() replaces the deprecated oauth2client flow)
#The spreadsheet must be shared with the service account's email address
client = gspread.service_account(filename='path/to/credentials.json')

#Get the instance of the Spreadsheet
sheet = client.open('cerebellum_gene_expression2')

#Get the first sheet of the Spreadsheet
sheet_instance = sheet.get_worksheet(0)

#Get all the records of the data
records_data = sheet_instance.get_all_records()

#Convert the records (a list of dicts) to a DataFrame
records_df = pd.DataFrame.from_dict(records_data)

#Select gene columns (excluding 'rownames' and 'cerebellum')
gene_cols = [col for col in records_df.columns if col not in ['rownames', 'cerebellum']]

#Randomly select 20 genes
#(NumPy's random generator differs from R's, so this selection will not match the R example)
np.random.seed(123)  # for reproducibility
selected_genes = np.random.choice(gene_cols, 20, replace=False)

#Create a subset of the data with only the selected genes
df_subset = records_df[['rownames'] + list(selected_genes)]

#Calculate correlation matrix
cor_matrix = df_subset.iloc[:, 1:].corr()

#Create a correlation plot with a colorblind-friendly palette
plt.figure(figsize=(10, 10))
sns.heatmap(cor_matrix, annot=True, fmt='.2f', cmap='viridis', cbar=True, square=True, 
            linewidths=.5, linecolor='black', xticklabels=True, yticklabels=True)
plt.title('Gene Correlation Plot (Colorblind Friendly)')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()

#Save the plot
plt.savefig('gene_correlation_plot_colorblind_friendly.png')
plt.show()

#Print the selected genes and the first few rows and columns of the correlation matrix
print('Selected genes:')
print(selected_genes)
print('First few rows and columns of the correlation matrix:')
print(cor_matrix.iloc[:5, :5])

2. Bubble Chart

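A bubble chart displays three dimensions of data in a two-dimensional plot. Each ‘bubble’ represents a data point: its horizontal and vertical positions encode two variables, and its size encodes a third. Colour can encode a fourth variable, as in the example below, where it shows population density.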
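Bubble size should usually encode the third variable through the bubble's area rather than its radius: if the radius were proportional to the value, a value twice as large would be drawn four times as large. Matplotlib's `s` argument is already measured in points squared, which is why the Python example below can pass a scaled population directly, and ggplot2's `scale_size()` family also scales area by default. A minimal sketch of area-proportional scaling; the helper name `bubble_areas` and its `max_area` default are assumptions for illustration:

```python
import math


def bubble_areas(values, max_area=1000.0):
    """Return marker areas proportional to `values`.

    The largest value gets `max_area` (matplotlib measures `s` in points squared).
    Assumes the values are non-negative, e.g. population counts.
    """
    largest = max(values)
    if largest <= 0:
        raise ValueError("need at least one positive value")
    return [v / largest * max_area for v in values]


populations = [100_000, 500_000, 1_000_000]
areas = bubble_areas(populations)

# The drawn width of a bubble grows with the square root of its area,
# so a value 10 times larger gives a bubble only about 3.2 times wider
widths = [math.sqrt(a / areas[0]) for a in areas]

print(areas)
print([round(w, 2) for w in widths])
```

The resulting list can be passed as the `s` argument of `plt.scatter()`.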

The dataset used to create this graph comes from the 2000 US census and can be accessed here. We will investigate the relationship between education level, poverty, total population, and population density in the 15 most populous counties in Illinois.

R Example

#R CODE
#Load necessary libraries
library(googlesheets4)
library(ggplot2)
library(dplyr)

#Read the data from Google Sheets (gs4_deauth() reads a publicly shared sheet without logging in)
gs4_deauth()
df <- read_sheet('https://docs.google.com/spreadsheets/d/1-pdUfk27NvK2L8Jxh5q0EW00FeYovwFPSl-fMefmIG0/edit?usp=sharing')

#Filter for Illinois counties and select the top 15 by population
il_counties <- df %>% filter(state == 'IL') %>% top_n(15, poptotal)

#Create the bubble chart
p <- ggplot(il_counties, aes(x = perchsd, y = percbelowpoverty, size = poptotal, color = popdensity)) +
  geom_point(alpha = 0.7) +
  scale_size_continuous(name = "Total Population", range = c(3, 15), breaks = c(100000, 500000, 1000000), labels = c("100k", "500k", "1M")) +
  scale_color_viridis_c(name = "Population Density") +
  labs(title = "Top 15 Illinois Counties: Education, Poverty, Population, and Density",
       x = "Percentage with High School Diploma",
       y = "Percentage Below Poverty Line") +
  theme_minimal() +
  theme(legend.position = "right") +
  geom_text(aes(label = county), hjust = 0.5, vjust = -1, size = 3)

#Save the plot
ggsave("top_15_illinois_counties_bubble_chart_r.png", plot = p, width = 14, height = 9)

#Print confirmation message
print("Bubble chart for top 15 Illinois counties has been created and saved as 'top_15_illinois_counties_bubble_chart_r.png'.")

#Display the data for these top 15 Illinois counties
print("Data for the top 15 Illinois counties:")
print(il_counties %>% select(county, poptotal, popdensity, perchsd, percbelowpoverty))

Python Example

#PYTHON CODE
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

#Read the CSV file
df = pd.read_csv("google_sheets_data.csv")

#Filter for Illinois counties and select the top 15 by population
il_counties = df[df['state'] == 'IL'].nlargest(15, 'poptotal')

#Create the bubble chart
plt.figure(figsize=(14, 9))  # Increased figure size to accommodate legend
scatter = plt.scatter(il_counties['perchsd'], il_counties['percbelowpoverty'], 
                      s=il_counties['poptotal']/1000, c=il_counties['popdensity'], 
                      cmap='viridis', alpha=0.7)

plt.xlabel('Percentage with High School Diploma')
plt.ylabel('Percentage Below Poverty Line')
plt.title('Top 15 Illinois Counties: Education, Poverty, Population, and Density')

#Add county names as labels
for i, row in il_counties.iterrows():
    plt.annotate(row['county'], (row['perchsd'], row['percbelowpoverty']), 
                 xytext=(5, 5), textcoords='offset points', fontsize=8)

#Add a colorbar
cbar = plt.colorbar(scatter)
cbar.set_label('Population Density')

#Add a size legend
sizes = [100000, 500000, 1000000]
labels = ['100k', '500k', '1M']
legend_elements = [plt.scatter([], [], s=size/1000, c='gray', alpha=0.6, label=label)
                   for size, label in zip(sizes, labels)]
plt.legend(handles=legend_elements, title='Total Population', loc='upper right', bbox_to_anchor=(1.25, 1))

plt.tight_layout()
plt.savefig('top_15_illinois_counties_bubble_chart_final_adjusted.png', bbox_inches='tight')
plt.close()

print("Bubble chart for top 15 Illinois counties with final adjusted legend placement has been created and saved as 'top_15_illinois_counties_bubble_chart_final_adjusted.png'.\n")

#Display the data for these top 15 Illinois counties
print("Data for the top 15 Illinois counties:")
print(il_counties[['county', 'poptotal', 'popdensity', 'perchsd', 'percbelowpoverty']])

3. Scatter Plot

A scatter plot displays the values of two variables for a set of data points, with one variable on each axis. It shows how the two variables vary together, revealing potential relationships between them, although a pattern in a scatter plot does not by itself show that one variable causes the change in the other.
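The straight trendlines in the examples below (`geom_smooth(method = "lm")` in R and `sns.regplot()` in Python) are ordinary least-squares fits. For one predictor, the slope is the sum of the products of the x and y deviations divided by the sum of squared x deviations, and the line passes through the point of means. A minimal sketch with made-up numbers (not the insurance data):

```python
def least_squares_line(xs, ys):
    """Return (slope, intercept) of the ordinary least-squares line y = slope * x + intercept."""
    n = len(xs)
    if n != len(ys) or n < 2:
        raise ValueError("need two sequences of equal length (at least 2 values)")
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Co-variation of x and y, and variation of x, around their means
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    if sxx == 0:
        raise ValueError("x must vary to fit a line")
    slope = sxy / sxx
    # The fitted line always passes through (mean_x, mean_y)
    intercept = mean_y - slope * mean_x
    return slope, intercept


# Made-up advertising spend (x) and quotes (y), for illustration only
ad_spend = [6.0, 7.5, 8.0, 9.5, 10.0]
quotes = [12.1, 14.0, 15.2, 17.1, 18.0]
slope, intercept = least_squares_line(ad_spend, quotes)
print(f"Fitted line: Quotes = {slope:.2f} * TV.advert + {intercept:.2f}")
```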

For this visualization, we are using a dataset called ‘insurance’, which can be accessed here. It contains monthly quotations and monthly television advertising expenditure for a US-based insurance company, collected from January 2002 to April 2005.

R Example

# R CODE
# Load necessary libraries
library(ggplot2)
library(readxl)

# Read the Excel file (replace the path with your copy of the insurance data)
file_path <- "example.xlsx"
df <- read_excel(file_path, sheet = "Sheet1")

# Create a scatterplot with a trendline
plot <- ggplot(df, aes(x = TV.advert, y = Quotes)) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm", se = FALSE, color = "blue") +
  labs(title = "Scatterplot of TV Advertisements vs Quotes",
       x = "TV Advertisements",
       y = "Quotes") +
  theme_minimal()

# Display the plot
print(plot)

Python Example

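#PYTHON CODE
# Import necessary libraries for reading the data and plotting
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Read the same Excel file as the R example (reading .xlsx files requires openpyxl)
df = pd.read_excel("example.xlsx", sheet_name="Sheet1")

# Set the style of the visualization
sns.set_theme(style="whitegrid")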

# Create a scatterplot with a trendline
plt.figure(figsize=(10, 6))
sns.regplot(x="TV.advert", y="Quotes", data=df, ci=None, scatter_kws={"s": 50, "alpha": 0.5})

# Add titles and labels
plt.title("Scatterplot of TV Advertisements vs Quotes")
plt.xlabel("TV Advertisements")
plt.ylabel("Quotes")

# Show the plot
plt.show()

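4. Hexagonal Binning

Hexagonal binning is a technique for large, dense datasets with two continuous numerical variables. It divides the plot area into hexagonal cells and colours each cell by the number of points that fall inside it, which shows the distribution and density of the points. This is particularly useful when overplotting (points drawn on top of one another) makes trends in an ordinary scatter plot difficult to discern.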
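The core of hexagonal binning is assigning each point to the hexagon that contains it and counting the points per hexagon; `geom_hex()` and `plt.hexbin()` then colour each hexagon by its count. A minimal pure-Python sketch of that assignment step, assuming regular pointy-top hexagons and the axial-coordinate scheme commonly used for hex grids (the function names and the `size` parameter, the centre-to-corner radius, are illustrative and not part of either library):

```python
import math
from collections import Counter

SQRT3 = math.sqrt(3)


def axial_round(q, r):
    """Round fractional axial coordinates to the nearest hexagon (cube rounding)."""
    s = -q - r
    rq, rr, rs = round(q), round(r), round(s)
    dq, dr, ds = abs(rq - q), abs(rr - r), abs(rs - s)
    # Recompute the coordinate with the largest rounding error so q + r + s == 0 holds
    if dq > dr and dq > ds:
        rq = -rr - rs
    elif dr > ds:
        rr = -rq - rs
    return rq, rr


def point_to_hex(x, y, size):
    """Axial (q, r) of the pointy-top hexagon with centre-to-corner radius `size` containing (x, y)."""
    q = (SQRT3 / 3 * x - y / 3) / size
    r = (2 / 3 * y) / size
    return axial_round(q, r)


def hex_center(q, r, size):
    """Cartesian centre of hexagon (q, r)."""
    return size * SQRT3 * (q + r / 2), size * 1.5 * r


def hex_bin(points, size):
    """Count how many points fall in each hexagon."""
    return Counter(point_to_hex(x, y, size) for x, y in points)


points = [(0.0, 0.0), (0.2, 0.1), (1.8, 0.1), (1.7, -0.2), (1.75, 0.0)]
counts = hex_bin(points, size=1.0)
for (q, r), n in counts.items():
    cx, cy = hex_center(q, r, 1.0)
    print(f"hexagon centred at ({cx:.2f}, {cy:.2f}): {n} point(s)")
```

A larger `size` gives fewer, coarser hexagons; this plays roughly the same role as `bins` in `geom_hex()` and `gridsize` in `plt.hexbin()`, which instead set how many hexagons span the axis.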

For this visualization, we will use a dataset of daily observations of the S&P 500 stock market index from 1950 to 2018. It can be accessed here.

R Example

#R CODE
# Load necessary libraries
library(ggplot2)
library(hexbin)

# Read the CSV file
df <- read.csv('sp500_1950_2018.csv')

# Create a hexagonal binning visualization for 'Close' and 'Volume'
p <- ggplot(df, aes(x = Close, y = Volume)) +
  geom_hex(bins = 50) +
  scale_fill_viridis_c(trans = "log10", name = "Count in bin (log scale)") +
  labs(title = "Hexagonal Binning of Close Price vs Volume (Log Scale)",
       x = "Close Price",
       y = "Volume") +
  theme_minimal()

# Display the plot
print(p)

# Print basic statistics
cat("Close price range:\n")
cat(sprintf("Min: %.2f\n", min(df$Close)))
cat(sprintf("Max: %.2f\n", max(df$Close)))
cat(sprintf("Mean: %.2f\n", mean(df$Close)))
cat(sprintf("Median: %.2f\n", median(df$Close)))

cat("\nVolume range:\n")
cat(sprintf("Min: %s\n", format(min(df$Volume), big.mark = ",")))
cat(sprintf("Max: %s\n", format(max(df$Volume), big.mark = ",")))
cat(sprintf("Mean: %s\n", format(mean(df$Volume), big.mark = ",", scientific = FALSE)))
cat(sprintf("Median: %s\n", format(median(df$Volume), big.mark = ",", scientific = FALSE)))

# Calculate correlation
correlation <- cor(df$Close, df$Volume)
cat(sprintf("\nCorrelation between Close price and Volume: %.4f\n", correlation))

Python Example

#PYTHON CODE
# Import necessary libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Read the CSV file
df = pd.read_csv('sp500_1950_2018.csv')

# Create a hexagonal binning visualization for 'Close' and 'Volume'
plt.figure(figsize=(12, 10))

# Create the hexbin plot
hb = plt.hexbin(df['Close'], df['Volume'], gridsize=50, cmap='viridis', bins='log')

# Add color bar
cb = plt.colorbar(label='log10(count) in bin')

# Add labels and title
plt.xlabel('Close Price')
plt.ylabel('Volume')
plt.title('Hexagonal Binning of Close Price vs Volume (Log Scale)')

# Show the plot
plt.show()

# Print some basic statistics
print("Close price range:")
print(f"Min: {df['Close'].min():.2f}")
print(f"Max: {df['Close'].max():.2f}")
print(f"Mean: {df['Close'].mean():.2f}")
print(f"Median: {df['Close'].median():.2f}")

print("\nVolume range:")
print(f"Min: {df['Volume'].min():,}")
print(f"Max: {df['Volume'].max():,}")
print(f"Mean: {df['Volume'].mean():,.0f}")
print(f"Median: {df['Volume'].median():,.0f}")

# Calculate correlation
correlation = df['Close'].corr(df['Volume'])
print(f"\nCorrelation between Close price and Volume: {correlation:.4f}")

5. Contour Plot + Surface Plot

Contour plots and surface plots are used to visualize a third variable, such as a density, over a two-dimensional field. A contour plot represents three-dimensional data in two dimensions: each contour line connects points that share the same value of the third variable, in the same way that the lines on a topographic map connect points of equal elevation. A surface plot instead draws the third variable as the height of a 3D surface. For simplicity, we are going to plot the function Z = sin(sqrt(X^2 + Y^2)).
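For this particular function the contour lines can be worked out by hand. Z depends only on the distance r = sqrt(X^2 + Y^2) from the origin, so every contour line is a circle centred on the origin, at each radius where sin(r) equals the chosen level. A small sketch (the helper name `contour_radii` is illustrative) that lists those radii up to a maximum radius:

```python
import math


def contour_radii(level, r_max):
    """Radii in [0, r_max] where sin(r) == level.

    For Z = sin(sqrt(X^2 + Y^2)), each radius is a circular contour line
    centred on the origin.
    """
    if not -1.0 <= level <= 1.0:
        return []  # sin(r) never reaches this level, so there is no contour
    base = math.asin(level)  # principal solution, in [-pi/2, pi/2]
    radii = set()
    k = 0
    # Solutions repeat every 2*pi: r = base + 2*pi*k and r = pi - base + 2*pi*k
    while base + 2 * math.pi * k <= r_max:
        for r in (base + 2 * math.pi * k, math.pi - base + 2 * math.pi * k):
            if 0 <= r <= r_max:
                radii.add(r)
        k += 1
    return sorted(radii)


for level in (0.0, 0.5):
    print(level, [round(r, 2) for r in contour_radii(level, r_max=10)])
```

Because the plotted grid runs from -10 to 10 on both axes, only circles with a radius up to 10 are drawn in full; larger ones appear as arcs in the corners, which reach a radius of about 14.1 (10 times the square root of 2).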

R Example


#R CODE
# Install and load necessary libraries
if (!requireNamespace("plotly", quietly = TRUE)) {
  install.packages("plotly", repos = "https://cran.rstudio.com/", dependencies = TRUE, Ncpus = 4)
}
library(plotly)

# Create a grid of x and y values
x <- seq(-10, 10, length.out = 100)
y <- seq(-10, 10, length.out = 100)

# Calculate z values as a matrix whose rows follow y and columns follow x,
# which is the orientation plotly expects for surface and contour traces
z_matrix <- outer(y, x, function(y, x) sin(sqrt(x^2 + y^2)))

# Create the surface plot
surface_plot <- plot_ly(x = x, y = y, z = z_matrix) %>%
  add_surface() %>%
  layout(
    scene = list(
      xaxis = list(title = "X"),
      yaxis = list(title = "Y"),
      zaxis = list(title = "Z")
    ),
    title = "Surface Plot of Z = sin(sqrt(X^2 + Y^2))"
  )

# Create the contour plot
contour_plot <- plot_ly(x = x, y = y, z = z_matrix, type = "contour") %>%
  layout(
    xaxis = list(title = "X"),
    yaxis = list(title = "Y"),
    title = "Contour Plot of Z = sin(sqrt(X^2 + Y^2))"
  )

# Display the plots
print(surface_plot)
print(contour_plot)

# Print a message to confirm execution
cat("Plots have been generated. You should see them in the output.")

Python Example

#PYTHON CODE
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # only needed on Matplotlib < 3.2 to register the '3d' projection

# Create a grid of x and y values and calculate Z = sin(sqrt(X^2 + Y^2))
x = np.linspace(-10, 10, 100)
y = np.linspace(-10, 10, 100)
X, Y = np.meshgrid(x, y)
Z = np.sin(np.sqrt(X**2 + Y**2))

# Create a figure with room for both plots
fig = plt.figure(figsize=(12, 6))

# Add a subplot for the surface plot
ax = fig.add_subplot(121, projection='3d')
surface = ax.plot_surface(X, Y, Z, cmap='viridis')
fig.colorbar(surface, ax=ax, shrink=0.5, aspect=5)
ax.set_title(r'Surface Plot of $Z = \sin(\sqrt{X^2 + Y^2})$')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
ax.set_zlabel('Z-axis')

# Add a subplot for the contour plot
ax2 = fig.add_subplot(122)
contour = ax2.contour(X, Y, Z, levels=20, cmap='viridis')
fig.colorbar(contour, ax=ax2)
ax2.set_title(r'Contour Plot of $Z = \sin(\sqrt{X^2 + Y^2})$')
ax2.set_xlabel('X-axis')
ax2.set_ylabel('Y-axis')
ax2.grid(True)

plt.show()

This post is part of a multi-part series. You can find the other posts below:

Visualization: Geospatial and Other Charts

Visualization: Data Over Time (Temporal)

Visualization: Distribution Charts

Visualization: Part-to-Whole Charts

Visualization: Comparison Charts

Happy graphing!
