
Mask R-CNN is not loading weights properly for inference and re-training

QUESTION:

I'm new to the world of computer vision, and this is my second project in the field. I am running an edited version of the Matterport Mask R-CNN that runs with tensorflow-gpu==2.7.0 (I found out later that an older TensorFlow version would have worked just fine). I am trying to use it with a pen dataset I created.

Anyway, the problem is that whenever I load the trained weights into the model to resume training, the loss metrics all skyrocket back up. I also get bad predictions when I load them for inference. Why are my weights not loading or saving properly? I save the weights with callbacks and load them using the following:

import mrcnn.model as modellib

# MODEL_DIR is the directory the training callbacks wrote checkpoints to
model = modellib.MaskRCNN(mode="inference",
                          config=inference_config,
                          model_dir=MODEL_DIR)

# Get path to the most recent saved weights
model_path = model.find_last()

# Load trained weights
print("Loading weights from ", model_path)
model.load_weights(model_path, by_name=True)

WHAT I'VE TRIED:

I have tried saving the whole model by setting save_weights_only in the callbacks to False. I ran into the get_config() issue in this thread and tried some of the solutions there, but to no avail.

I have also tried experimenting with different image sizes and numbers of epochs.

I have tried saving the model using:

from tensorflow import keras

# Save the underlying Keras model, then try to reload it
model.keras_model.save('path/to/location')
model = keras.models.load_model('path/to/location')

which led to the same get_config() issue.
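For context on the `get_config()` error: when Keras reloads a full model, it rebuilds each layer from the dictionary returned by `get_config()`, and the default `from_config()` does this as `cls(**config)`. Custom layers whose constructors take extra arguments (in the Matterport code, layers such as `ProposalLayer` also take a `config` object) therefore need to override `get_config()`. The sketch below imitates that contract in plain Python so it runs without TensorFlow; `BaseLayer` and `ProposalLayerSketch` are hypothetical stand-ins, not Keras classes.

```python
class BaseLayer:
    """Stand-in for a Keras layer base class (plain Python, not Keras)."""

    def __init__(self, name=None):
        self.name = name

    def get_config(self):
        # The base class only knows about its own constructor arguments
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        # Mirrors Keras' default: rebuild the layer as cls(**config)
        return cls(**config)


class ProposalLayerSketch(BaseLayer):
    def __init__(self, proposal_count, nms_threshold, **kwargs):
        super().__init__(**kwargs)
        self.proposal_count = proposal_count
        self.nms_threshold = nms_threshold

    def get_config(self):
        # Merge the parent's config with this layer's own constructor arguments
        config = super().get_config()
        config.update({
            "proposal_count": self.proposal_count,
            "nms_threshold": self.nms_threshold,
        })
        return config


layer = ProposalLayerSketch(proposal_count=1000, nms_threshold=0.7, name="ROI")
clone = ProposalLayerSketch.from_config(layer.get_config())
print(clone.name, clone.proposal_count, clone.nms_threshold)  # -> ROI 1000 0.7
```

Without the `get_config()` override, `from_config()` would call the constructor with only `name`, and rebuilding would fail for lack of `proposal_count` and `nms_threshold`. An argument like a `config` object is harder still, because it is not a plain JSON-serializable value.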

RESOURCES:

Here is a list of the things I am running:

 # ITEM ########### VERSION ##########################
 # Python         # 3.9.7                            #
 # conda          # 4.10.3                           # 
 # CUDA           # 11.4                             #
 # WindowsOS      # 11                               #
 # cuDNN          # 8.2.4                            # 
 #####################################################
 
################################### PACKAGES ##################################
# packages in environment at C:\Users\ecsan\anaconda3\envs\Prototype:
# Command: conda list
# Name #################### Version ################ Build # Channel ############
# absl-py                   1.0.0                    pypi_0    pypi           #
# alabaster                 0.7.12                   pypi_0    pypi           #
# argon2-cffi               21.1.0                   pypi_0    pypi           #
# astunparse                1.6.3                    pypi_0    pypi           #
# attrs                     21.2.0                   pypi_0    pypi           #
# babel                     2.9.1                    pypi_0    pypi           #
# backcall                  0.2.0                    pypi_0    pypi           #
# bleach                    4.1.0                    pypi_0    pypi           #
# ca-certificates           2021.10.8            h5b45459_0    conda-forge    # 
# cachetools                4.2.4                    pypi_0    pypi           #
# certifi                   2021.10.8                pypi_0    pypi           #
# cffi                      1.15.0                   pypi_0    pypi           #
# charset-normalizer        2.0.9                    pypi_0    pypi           #
# colorama                  0.4.4                    pypi_0    pypi           #
# console_shortcut          0.1.1                         4                   #
# cycler                    0.11.0                   pypi_0    pypi           #
# cython                    0.29.25                  pypi_0    pypi           #
# debugpy                   1.5.1                    pypi_0    pypi           #
# decorator                 5.1.0                    pypi_0    pypi           #
# defusedxml                0.7.1                    pypi_0    pypi           #
# dill                      0.3.4                    pypi_0    pypi           #
# docutils                  0.17.1                   pypi_0    pypi           #
# entrypoints               0.3                      pypi_0    pypi           #
# flatbuffers               2.0                      pypi_0    pypi           #
# fonttools                 4.28.3                   pypi_0    pypi           #
# gast                      0.4.0                    pypi_0    pypi           #
# google-auth               2.3.3                    pypi_0    pypi           #
# google-auth-oauthlib      0.4.6                    pypi_0    pypi           #
# google-pasta              0.2.0                    pypi_0    pypi           #
# grpcio                    1.42.0                   pypi_0    pypi           #
# h5py                      3.6.0                    pypi_0    pypi           #
# idna                      3.3                      pypi_0    pypi           #
# imageio                   2.13.2                   pypi_0    pypi           #
# imagesize                 1.3.0                    pypi_0    pypi           #
# imgaug                    0.4.0                    pypi_0    pypi           #
# importlib-metadata        4.8.2                    pypi_0    pypi           #
# ipykernel                 6.6.0                    pypi_0    pypi           #
# ipyparallel               8.0.0                    pypi_0    pypi           #
# ipython                   7.30.1                   pypi_0    pypi           #
# ipython-genutils          0.2.0                    pypi_0    pypi           #
# ipywidgets                7.6.5                    pypi_0    pypi           #
# jedi                      0.18.1                   pypi_0    pypi           #
# jinja2                    3.0.3                    pypi_0    pypi           #
# joblib                    1.1.0                    pypi_0    pypi           #
# jsonschema                4.2.1                    pypi_0    pypi           #
# jupyter-client            7.1.0                    pypi_0    pypi           #
# jupyter-core              4.9.1                    pypi_0    pypi           #
# jupyterlab-pygments       0.1.2                    pypi_0    pypi           #
# jupyterlab-widgets        1.0.2                    pypi_0    pypi           #
# keras                     2.7.0                    pypi_0    pypi           #
# keras-preprocessing       1.1.2                    pypi_0    pypi           #
# kiwisolver                1.3.2                    pypi_0    pypi           #
# libclang                  12.0.0                   pypi_0    pypi           #
# markdown                  3.3.6                    pypi_0    pypi           #
# markupsafe                2.0.1                    pypi_0    pypi           #
# matplotlib                3.5.0                    pypi_0    pypi           #
# matplotlib-inline         0.1.3                    pypi_0    pypi           #
# mistune                   0.8.4                    pypi_0    pypi           #
# nbclient                  0.5.9                    pypi_0    pypi           #
# nbconvert                 6.3.0                    pypi_0    pypi           #
# nbformat                  5.1.3                    pypi_0    pypi           #
# nest-asyncio              1.5.4                    pypi_0                   #
# networkx                  2.6.3                    pypi_0    pypi           #
# nose                      1.3.7                    pypi_0    pypi           #
# notebook                  6.4.6                    pypi_0    pypi           #
# numpy                     1.19.5                   pypi_0    pypi           #
# oauthlib                  3.1.1                    pypi_0    pypi           #
# opencv-python             4.5.4.60                 pypi_0    pypi           #
# openssl                   3.0.0                h8ffe710_2    conda-forge    #
# opt-einsum                3.3.0                    pypi_0    pypi           #
# packaging                 21.3                     pypi_0    pypi           #
# pandocfilters             1.5.0                    pypi_0    pypi           #
# parso                     0.8.3                    pypi_0    pypi           #
# pickleshare               0.7.5                    pypi_0    pypi           #
# pillow                    8.4.0                    pypi_0    pypi           #
# pip                       21.3.1             pyhd8ed1ab_0    conda-forge    #
# prometheus-client         0.12.0                   pypi_0    pypi           #
# prompt-toolkit            3.0.23                   pypi_0    pypi           #
# protobuf                  3.19.1                   pypi_0    pypi           #
# psutil                    5.8.0                    pypi_0    pypi           #
# pyasn1                    0.4.8                    pypi_0    pypi           #
# pyasn1-modules            0.2.8                    pypi_0    pypi           #
# pycparser                 2.21                     pypi_0    pypi           #
# pygments                  2.10.0                   pypi_0    pypi           #
# pyparsing                 3.0.6                    pypi_0    pypi           #
# pyrsistent                0.18.0                   pypi_0    pypi           #
# python                    3.9.7        h900ac77_3_cpython    conda-forge    #
# python-dateutil           2.8.2                    pypi_0    pypi           #
# python_abi                3.9                      2_cp39    conda-forge    #
# pytz                      2021.3                   pypi_0    pypi           #
# pywavelets                1.2.0                    pypi_0    pypi           #
# pywin32                   302                      pypi_0    pypi           #
# pywinpty                  1.1.6                    pypi_0    pypi           #
# pyzmq                     22.3.0                   pypi_0    pypi           #
# qtconsole                 5.2.1                    pypi_0    pypi           #
# qtpy                      1.11.3                   pypi_0    pypi           #
# requests                  2.26.0                   pypi_0    pypi           #
# requests-oauthlib         1.3.0                    pypi_0    pypi           #
# rsa                       4.8                      pypi_0    pypi           #
# scikit-image              0.18.3                   pypi_0    pypi           #
# scipy                     1.7.3                    pypi_0    pypi           #
# send2trash                1.8.0                    pypi_0    pypi           #
# setuptools                59.4.0           py39hcbf5309_0    conda-forge    #
# setuptools-scm            6.3.2                    pypi_0    pypi           #
# shapely                   1.8.0                    pypi_0    pypi           #
# six                       1.15.0                   pypi_0    pypi           #
# snowballstemmer           2.2.0                    pypi_0    pypi           #
# sphinx                    4.3.1                    pypi_0    pypi           #
# sphinxcontrib-applehelp   1.0.2                    pypi_0    pypi           #
# sphinxcontrib-devhelp     1.0.2                    pypi_0    pypi           #
# sphinxcontrib-htmlhelp    2.0.0                    pypi_0    pypi           #
# sphinxcontrib-jsmath      1.0.1                    pypi_0    pypi           #
# sphinxcontrib-qthelp      1.0.3                    pypi_0    pypi           #
# sphinxcontrib-serializinghtml 1.1.5                pypi_0    pypi           #
# sqlite                    3.37.0               h8ffe710_0    conda-forge    #
# tb-nightly                2.8.0a20211220           pypi_0    pypi           #
# tensorboard               2.7.0                    pypi_0    pypi           #
# tensorboard-data-server   0.6.1                    pypi_0    pypi           #
# tensorboard-plugin-wit    1.8.0                    pypi_0    pypi           #
# tensorflow-estimator      2.7.0                    pypi_0    pypi           #
# tensorflow-gpu            2.7.0                    pypi_0    pypi           #
# tensorflow-io-gcs-filesystem 0.23.1                pypi_0    pypi           #
# termcolor                 1.1.0                    pypi_0    pypi           #
# terminado                 0.12.1                   pypi_0    pypi           #
# testpath                  0.5.0                    pypi_0    pypi           #
# tf-estimator-nightly      2.8.0.dev2021122009      pypi_0    pypi           #
# tifffile                  2021.11.2                pypi_0    pypi           #
# tomli                     1.2.2                    pypi_0    pypi           #
# tornado                   6.1                      pypi_0    pypi           #
# tqdm                      4.62.3                   pypi_0    pypi           #
# traitlets                 5.1.1                    pypi_0    pypi           #
# typing-extensions         4.0.1                    pypi_0    pypi           #
# tzdata                    2021e                he74cb21_0    conda-forge    #
# ucrt                      10.0.20348.0         h57928b3_0    conda-forge    #
# urllib3                   1.26.7                   pypi_0    pypi           #
# vc                        14.2                 hb210afc_5    conda-forge    #
# vs2015_runtime            14.29.30037          h902a5da_5    conda-forge    #
# wcwidth                   0.2.5                    pypi_0    pypi           #
# webencodings              0.5.1                    pypi_0    pypi           #
# werkzeug                  2.0.2                    pypi_0    pypi           #
# wheel                     0.37.0             pyhd8ed1ab_1    conda-forge    #
# widgetsnbextension        3.5.2                    pypi_0    pypi           #
# wrapt                     1.13.3                   pypi_0    pypi           #
# zipp                      3.6.0                    pypi_0    pypi           #
###############################################################################

Here is a link to my TensorBoard and an example of a bad prediction:

In TensorBoard you can see the model learning and then a spike at the end; that spike is where I loaded the weights again and resumed training.

https://tensorboard.dev/experiment/KkgugOP7RGu12lVCA6M29Q/

[Image: Bad Prediction]
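Since the spike appears right after reloading, it is worth confirming which file `model.find_last()` returned. As far as I can tell from the Matterport source, it takes the run directory under `model_dir` whose name starts with the lower-cased config `NAME` and sorts last, then the `mask_rcnn*` file in it that sorts last. The sketch below imitates that rule on a hypothetical directory listing instead of walking the disk; `pick_last_checkpoint` is a hypothetical helper, not the library method.

```python
import os


def pick_last_checkpoint(model_dir, listing, config_name):
    """Return the checkpoint path the find_last() rule would choose.

    listing: hypothetical mapping of run-directory name -> file names in it.
    """
    key = config_name.lower()
    # Run directories are named like "pen20211222T1530", so the newest sorts last
    runs = sorted(d for d in listing if d.startswith(key))
    if not runs:
        raise FileNotFoundError(f"No run directories starting with {key!r}")
    last_run = runs[-1]
    # Checkpoints are named like "mask_rcnn_pen_0010.h5", so the highest epoch sorts last
    checkpoints = sorted(f for f in listing[last_run] if f.startswith("mask_rcnn"))
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoints found in {last_run!r}")
    return os.path.join(model_dir, last_run, checkpoints[-1])


# Hypothetical listing: an older 30-epoch run and a newer run stopped after epoch 1
listing = {
    "pen20211220T0900": ["mask_rcnn_pen_0029.h5", "mask_rcnn_pen_0030.h5"],
    "pen20211222T1530": ["events.out.tfevents.1", "mask_rcnn_pen_0001.h5"],
}
print(pick_last_checkpoint("logs", listing, "PEN"))
```

With this listing the rule picks the epoch-1 file from the newer run, even though the older run trained for longer, so printing `model_path` before loading is a cheap way to rule that case in or out.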

Here is my custom config for training:

from mrcnn.config import Config


class CustomConfig(Config):
    """Configuration for training on the pen dataset.
    Derives from the base Config class and overrides some values.
    """

    DETECTION_MIN_CONFIDENCE = 0.7  # Skip detections with < 70% confidence
    # Give the configuration a recognizable name
    NAME = "PEN"

    # Train on 1 GPU and 8 images per GPU. We can put multiple images on each
    # GPU because the images are small. Batch size is 8 (GPUs * images/GPU).
    GPU_COUNT = 1
    IMAGES_PER_GPU = 8

    # Number of classes (including background)
    NUM_CLASSES = 1 + 1  # background + PEN

    # Use small images for faster training. Set the limits of the small side
    # and the large side, which determine the image shape.
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128

    # Use smaller anchors because our image and objects are small
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels

    # Reduce training ROIs per image because the images are small and have
    # few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
    TRAIN_ROIS_PER_IMAGE = 32

    # Use a small epoch since the data is simple
    STEPS_PER_EPOCH = 300

    # Use small validation steps since the epoch is small
    VALIDATION_STEPS = 10

    
config = CustomConfig()
config.display()
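When a Matterport `Config` is instantiated, it derives a few values from the ones set above, and `config.display()` prints them alongside the rest. The sketch below re-computes the relevant ones by hand for the training and inference settings in this post. It is a simplified stand-in (a hypothetical `derived_values` helper), not the library code, and it assumes the default `"square"` resize mode, where the image shape comes from `IMAGE_MAX_DIM`.

```python
def derived_values(images_per_gpu, gpu_count, image_max_dim, channels=3):
    """Hand re-computation of a few values Matterport's Config derives."""
    # BATCH_SIZE = IMAGES_PER_GPU * GPU_COUNT
    batch_size = images_per_gpu * gpu_count
    # In "square" resize mode every image is padded to IMAGE_MAX_DIM x IMAGE_MAX_DIM
    image_shape = (image_max_dim, image_max_dim, channels)
    # Matterport requires the side to be divisible by 2 six times (a multiple of 64)
    divisible_ok = image_max_dim % (2 ** 6) == 0
    return batch_size, image_shape, divisible_ok


print(derived_values(images_per_gpu=8, gpu_count=1, image_max_dim=128))  # -> (8, (128, 128, 3), True)
print(derived_values(images_per_gpu=1, gpu_count=1, image_max_dim=128))  # -> (1, (128, 128, 3), True)
```

Both configs keep the same 128x128 input and differ only in batch size, which is the kind of difference that should not by itself stop weights from loading.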

Here is my inference config:

class InferenceConfig(CustomConfig):
    NAME = "PEN"

    NUM_CLASSES = 1 + 1  # background + PEN

    # Use small images for faster training. Set the limits of the small side
    # and the large side, which determine the image shape.
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128

    # Use smaller anchors because our image and objects are small
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    DETECTION_MIN_CONFIDENCE = 0.9

If you need additional information, please let me know. This is also my first post, so any guidance is appreciated.



source https://stackoverflow.com/questions/70454143/mask-r-cnn-is-not-loading-weights-properly-for-inference-and-re-training
