TensorFlow: simple cumulative sum of products with an RNN cell

Hello, I am trying to build a TensorFlow model that computes a cumulative sum of the product of two input features. For example, predicting on (1, 2) alone should return 2 = 1 * 2, and then predicting on (2, 2) should return 6 = (1 * 2) + (2 * 2). After resetting the state, predicting on (2, 2) should return 4 again:

model.predict([1,2])
>>> 2

model.predict([2,2])
>>> 6

model.reset_states()
model.predict([2,2])
>>> 4
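
To pin down the target numbers independently of any model, here is a minimal NumPy sketch of the computation described above. The helper name `cumulative_sum_of_products` is only illustrative and is not part of the model code.

```python
import numpy as np


def cumulative_sum_of_products(rows):
    """Running sum of x1 * x2 over the rows seen so far."""
    rows = np.asarray(rows, dtype=float)
    # Multiply the two feature columns, then accumulate down the rows.
    return np.cumsum(rows[:, 0] * rows[:, 1])


print(cumulative_sum_of_products([[1, 2], [2, 2]]))  # [2. 6.]
```

Calling the helper on `[[2, 2]]` alone gives 4, which corresponds to the value expected after `model.reset_states()`.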

I have tried the following:

import numpy as np
import tensorflow as tf


class MinimalRNNCell(tf.keras.layers.Layer):
    """RNN cell whose state is the running sum of its inputs."""

    def __init__(self, units, **kwargs):
        # Initialise the base Layer before setting attributes.
        super().__init__(**kwargs)
        self.units = units
        # The RNN layer uses state_size to create the initial (zero) state.
        self.state_size = units

    def call(self, inputs, states):
        # states[0] holds the running sum from the previous time step.
        prev_output = states[0]
        output = tf.math.add(prev_output, inputs)

        return output, [output]
    

# Define the model
# Input: two features per row
inp = tf.keras.layers.Input(shape=(2,))
# Split the input into the two features
x1, x2 = tf.split(inp, num_or_size_splits=2, axis=1)
# Element-wise product of the two features
product = tf.math.multiply(x1, x2)
# Reshape to (timesteps=1, features=1), so each row becomes its own one-step sequence
time_product = tf.keras.layers.Reshape((1, 1))(product)
# Define the memory cell and the RNN layer that wraps it
memory_product = MinimalRNNCell(units=1)
layer_product = tf.keras.layers.RNN(memory_product)
# Cumulative sum of the products
cumulative_product = layer_product(time_product)

output = cumulative_product

model = tf.keras.models.Model(inp, output)


if __name__=="__main__":
    x = np.array([
        [1, 2],
        [2, 2]
    ])

    model.compile()
    y = model.predict(x)
    print()
    print("output: ", y)
>>> [[2],
     [4]]

I feel like something like a cumulative sum should be easy to implement with RNN or LSTM cells, but it does not behave the way I expect: the second row returns 4 instead of 6.
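
One possible reading of this output: `tf.keras.layers.RNN` accumulates state along the time axis of a single sequence and starts every batch row from a zero state. With `Reshape((1, 1))`, each row is its own one-step sequence, so nothing carries over from row to row. Below is a minimal sketch, assuming the rows can instead be fed as the time steps of one sequence. The name `CumSumCell` and the shapes are illustrative, and the layer is called directly on a tensor rather than through a functional model.

```python
import numpy as np
import tensorflow as tf


class CumSumCell(tf.keras.layers.Layer):
    """Illustrative cell: the state is the running sum of the inputs."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.state_size = 1

    def call(self, inputs, states):
        total = states[0] + inputs
        return total, [total]


x = np.array([[1, 2], [2, 2]], dtype="float32")
products = x[:, 0:1] * x[:, 1:2]        # shape (T, 1): one product per row
sequence = products[np.newaxis, ...]    # shape (1, T, 1): rows become time steps

# return_sequences=True exposes the running sum at every time step.
rnn = tf.keras.layers.RNN(CumSumCell(), return_sequences=True)
y = rnn(tf.constant(sequence))

print(y.numpy().reshape(-1))  # [2. 6.]
```

If the rows really have to arrive one `predict` call at a time, the `stateful=True` argument of `tf.keras.layers.RNN` together with a fixed batch size of 1 may be worth trying. That route is not shown here.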



source https://stackoverflow.com/questions/76348599/tensorflow-simple-cumulative-sum-of-product-rnn-cell
