Skip to main content

FID and custom feature extractor

I would like to use a custom feature extractor to calculate FID

according to https://lightning.ai/docs/torchmetrics/stable/image/frechet_inception_distance.html I can use nn.Module for feature

What is wrong with the following code?



import torch
_ = torch.manual_seed(123)
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.models import inception_v3


net = inception_v3()
checkpoint = torch.load('checkpoint.pt')
net.load_state_dict(checkpoint['state_dict'])
net.eval()

fid = FrechetInceptionDistance(feature=net)
# generate two slightly overlapping image intensity distributions
imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
fid.update(imgs_dist1, real=True)
fid.update(imgs_dist2, real=False)
result = fid.compute()

print(result)

Traceback (most recent call last):
  File "foo.py", line 12, in <module>
    fid = FrechetInceptionDistance(feature=net)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Lib/site-packages/torchmetrics/image/fid.py", line 304, in __init__
    num_features = self.inception(dummy_image).shape[-1]
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Lib/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Lib/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Lib/site-packages/torchvision/models/inception.py", line 166, in forward
    x, aux = self._forward(x)
             ^^^^^^^^^^^^^^^^
  File "/Lib/site-packages/torchvision/models/inception.py", line 105, in _forward
    x = self.Conv2d_1a_3x3(x)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/Lib/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Lib/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Lib/site-packages/torchvision/models/inception.py", line 405, in forward
    x = self.conv(x)
        ^^^^^^^^^^^^
  File "/Lib/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Lib/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Lib/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Lib/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: expected scalar type Byte but found Float

Process finished with exit code 1



source https://stackoverflow.com/questions/77675693/fid-and-custom-feature-extractor

Comments

Popular posts from this blog

Prop `className` did not match in next js app

I have written a sample code ( Github Link here ). this is a simple next js app, but giving me error when I refresh the page. This seems to be the common problem and I tried the fix provided in the internet but does not seem to fix my issue. The error is Warning: Prop className did not match. Server: "MuiBox-root MuiBox-root-1" Client: "MuiBox-root MuiBox-root-2". Did changes for _document.js, modified _app.js as mentioned in official website and solutions in stackoverflow. but nothing seems to work. Could someone take a look and help me whats wrong with the code? Via Active questions tagged javascript - Stack Overflow https://ift.tt/2FdjaAW

How to show number of registered users in Laravel based on usertype?

i'm trying to display data from the database in the admin dashboard i used this: <?php use Illuminate\Support\Facades\DB; $users = DB::table('users')->count(); echo $users; ?> and i have successfully get the correct data from the database but what if i want to display a specific data for example in this user table there is "usertype" that specify if the user is normal user or admin i want to user the same code above but to display a specific usertype i tried this: <?php use Illuminate\Support\Facades\DB; $users = DB::table('users')->count()->WHERE usertype =admin; echo $users; ?> but it didn't work, what am i doing wrong? source https://stackoverflow.com/questions/68199726/how-to-show-number-of-registered-users-in-laravel-based-on-usertype

Why is my reports service not connecting?

I am trying to pull some data from a Postgres database using Node.js and node-postures but I can't figure out why my service isn't connecting. my routes/index.js file: const express = require('express'); const router = express.Router(); const ordersCountController = require('../controllers/ordersCountController'); const ordersController = require('../controllers/ordersController'); const weeklyReportsController = require('../controllers/weeklyReportsController'); router.get('/orders_count', ordersCountController); router.get('/orders', ordersController); router.get('/weekly_reports', weeklyReportsController); module.exports = router; My controllers/weeklyReportsController.js file: const weeklyReportsService = require('../services/weeklyReportsService'); const weeklyReportsController = async (req, res) => { try { const data = await weeklyReportsService; res.json({data}) console