June 19, 2023

Scraping Google Images with Python to Generate Labeled Datasets

Scrape categories of images to generate labeled datasets

Labeled images, while necessary for many machine-learning applications, can be hard to get; how can we automate the data-gathering process to make labeled datasets?

Photo by Andrew S on Unsplash

Introduction

Computer vision and image generation models need data, and a lot of it. Unfortunately, labeled image datasets aren't always available: sometimes they simply don't exist, or the data is too sparse to be helpful. So how do you make your own image datasets?

One option is to scrape Google Images. Using only text descriptions, we can quickly scrape relevant images into groups to train our deep-learning models. I'll provide the code I use to do this and an in-depth explanation of every relevant part.

Today, we're going to make a dataset of cats and dogs. The idea is that you'll be able to replace these categories with whichever ones you want.

Getting Started

To begin, there are several libraries that you'll need to install:

pip install selenium
pip install webdriver_manager
pip install pillow
pip install requests

Selenium allows us to automate the browser. In our code, we will be automating Chrome through ChromeDriver, the driver program that Selenium uses to control the browser; webdriver_manager downloads and manages the right ChromeDriver version for us.

from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
from selenium.webdriver.common.by import By
import os
import time
import base64
import shutil
import requests
from PIL import Image
from io import BytesIO


# download a matching ChromeDriver (if needed) and launch an automated Chrome window
driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()))
driver.get('https://images.google.com')

If the code works, an automated browser window should open up on the Google Images webpage.

Now that we've got our setup working, let's search for one of our categories: dogs.

Searching for images on Google is relatively straightforward. The URL format is as follows:

  • https://www.google.com/search?tbm=isch&q=(insert query here)
    • isch stands for "image search"

Now, let's make a function to format the URL with our search query, replacing spaces with + so that multi-word queries stay URL-safe.

def get_url(search):
    url_format = 'https://www.google.com/search?tbm=isch&q={}'
    search = search.replace(' ', '+')  # URL-encode spaces in the query
    url = url_format.format(search)
    return url


driver.get(get_url('dogs'))

Now let's try extracting the image srcs from the results page so we can download the images. First, we need a way to identify the image elements on the webpage. Google's ids and classes don't have intuitive names, so I'm going to use an XPath expression.

all_imgs = driver.find_elements(By.XPATH, '//*[@class="rg_i Q4LuWd"]')

XPath expressions are, of course, vulnerable to layout changes. If this code doesn't work, inspect the results page and change the @class="..." portion to the new class used for image results.
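
If the class has changed and you'd rather not chase Google's latest markup, a looser fallback is to collect every img element and keep only those whose src looks like a usable thumbnail. This is a rough sketch, not Google-specific, and it will pick up some noise (logos, icons) that you may want to filter out later:

def find_result_images(driver):
    # fallback: grab all <img> elements and keep those whose src is
    # either an inline base64 image or an http(s) URL
    images = []
    for img in driver.find_elements(By.TAG_NAME, 'img'):
        src = img.get_attribute('src')
        if src and (src.startswith('data:image') or src.startswith('http')):
            images.append(img)
    return images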

When we're ready to get the page results, we can extract the src from each of the image elements like so:

all_srcs = [x.get_attribute('src') for x in all_imgs]

Scrolling

You may have noticed that Google doesn't load all image search results at first; we need to scroll down to see more.

In addition, after some time, Google will display a "Show More Results" button, which we'll need to click. That button has the class "mye4qd".

def handle_scroll(driver):
    # scrolls to the bottom of the page
    # show more results button: mye4qd
    # you've reached the end: OuJzKb Yu2Dnd
    last_height = driver.execute_script("return window.pageYOffset")
    num_trials = 0
    max_trials = 3
    while True:
        # click "Show More Results" whenever the button is present
        try:
            show_more_results_btn = driver.find_element(By.CLASS_NAME, 'mye4qd')
            show_more_results_btn.click()
            continue
        except Exception:
            pass

        driver.execute_script('window.scrollBy(0,100);')
        new_height = driver.execute_script("return window.pageYOffset")

        # if the offset has stopped increasing, retry a few times
        # before concluding that we've reached the bottom
        if new_height - last_height < 10:
            num_trials += 1
            if num_trials >= max_trials:
                break
            time.sleep(1)
        else:
            last_height = new_height
            num_trials = 0

The function handle_scroll works by:

  1. Recording the window's initial pageYOffset
  2. Scrolling down the page until the pageYOffset stops increasing, clicking the "Show More Results" button as necessary. To avoid being fooled by momentary page lags that prevent scrolling, we take several measurements of the page offset before exiting the while loop.
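
To sanity-check the scrolling, you can compare how many image elements are visible before and after it runs. A minimal sketch, assuming the driver and get_url from earlier:

driver.get(get_url('dogs'))
before = len(driver.find_elements(By.XPATH, '//*[@class="rg_i Q4LuWd"]'))

handle_scroll(driver)

after = len(driver.find_elements(By.XPATH, '//*[@class="rg_i Q4LuWd"]'))
print('image elements before: {}, after: {}'.format(before, after))

If scrolling is working, the second count should be substantially larger than the first.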

Creating the Images Directory

def setup(path, delete_all=True):
    # wipe the folder (if requested) and recreate it, so each run starts fresh;
    # makedirs also creates missing parent directories (e.g. the data folder)
    if delete_all and os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)

In the last code block, we introduced a new function, setup , which takes in a folder path, erases its contents, and recreates the path so that our program can start afresh and get new results. This way, we have clean folders to store our images.

Scraping Image Categories

def scrape_category(category_name, queries, data_dir):
    category_dir = os.path.join(data_dir, category_name)
    setup(category_dir)

    for query in queries:
        print('Scraping category: {} for query: {}'.format(category_name, query))
        driver.get(get_url(query))
        handle_scroll(driver)
        all_imgs = driver.find_elements(By.XPATH, '//*[@class="rg_i Q4LuWd"]')
        all_srcs = [x.get_attribute('src') for x in all_imgs]
        save_images(all_srcs, category_name, query, category_dir)

categories = {
    'dog': ['dog', 'golden retriever', 'husky dog', 'bulldog', 'dalmatian', 'poodle'],
    'cat': ['cat', 'siamese cat', 'persian cat', 'ragdoll cat', 'shorthair cat']
}

data_path = 'data'  # top-level folder that will hold a subfolder per category

for category in categories:
    category_queries = categories[category]
    scrape_category(category, category_queries, data_dir=data_path)

Google limits the number of images shown per search. To circumvent this issue, we can provide more specific keywords to our program. For example, under the "dog" category we can scrape "golden retriever" and "poodle".

For each category (the keys of the categories dictionary), we scrape by keyword, generating search URLs with the get_url function, which we then pass to the driver. Once we arrive on a page, we scroll with our handle_scroll function. When we reach the end of the results, we find all the available images like so:

all_imgs = driver.find_elements(By.XPATH, '//*[@class="rg_i Q4LuWd"]')
all_srcs = [x.get_attribute('src') for x in all_imgs]

The driver will search for elements on the page matching the specified XPath, and will then collect the "src" attribute from each element. From there, we can save the images to our category directory. On the first pass, we will be saving images into the "data/dog" folder.
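
Once both categories have run, the resulting directory layout should look roughly like this (the filenames come from the save_images function shown below):

data/
├── dog/
│   ├── dog_0.png
│   ├── golden_retriever_0.png
│   └── ...
└── cat/
    ├── cat_0.png
    ├── siamese_cat_0.png
    └── ...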

Saving Images

def save_images(srcs, category, keyword, data_dir):
    print('Done scraping, writing {} images to files'.format(category))
    srcs = [x for x in srcs if x]  # drop elements that had no src attribute
    for i, src in enumerate(srcs):
        unique_image_name = '{}_{}.png'.format(keyword.replace(' ', '_').replace(':', '_').replace('/', '_'), i)
        file_path = os.path.join(data_dir, unique_image_name)

        try:
            if 'data:image' in src:
                # inline thumbnail: decode the base64 payload
                readable_base64 = ''.join(src.split('base64,')[1:])
                content = base64.b64decode(readable_base64)
            else:
                # external thumbnail: download it
                r = requests.get(src, timeout=10)
                content = r.content

            img = Image.open(BytesIO(content)).convert('RGB')
            img.save(file_path)
        except Exception as e:
            # skip images that fail to download or decode rather than aborting the run
            print('Skipping image {}: {}'.format(i, e))

The save_images function takes in:

  • The image element srcs, which will either be image URLs or base64 strings (e.g., data:image/png;base64,...)
  • The current category of images being scraped (in our case, dogs)
  • The keyword string within the category being scraped (such as "golden retriever" or "poodle")
  • data_dir, the directory to save images into.

The function loops through the image srcs and either decodes the base64 strings or downloads the images from their URLs using the requests library. From there, the raw image bytes are converted to RGB and saved via the Pillow library, which is useful for image processing and saving.

Image filenames are a combination of the keyword and an index, with special characters (such as : and /) replaced to avoid errors; for example, results for "golden retriever" are saved as golden_retriever_0.png, golden_retriever_1.png, and so on.

Running the complete source code provided should result in nice, organized image datasets. Edit the categories and keywords as necessary for your use case.
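
For reference, here is a condensed sketch of the whole pipeline, assuming the functions defined above (get_url, handle_scroll, setup, scrape_category, and save_images) are in scope:

driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()))

data_path = 'data'
categories = {
    'dog': ['dog', 'golden retriever', 'husky dog', 'bulldog', 'dalmatian', 'poodle'],
    'cat': ['cat', 'siamese cat', 'persian cat', 'ragdoll cat', 'shorthair cat'],
}

for category, category_queries in categories.items():
    scrape_category(category, category_queries, data_dir=data_path)

# close the automated browser when done
driver.quit()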