Creating your own dataset in tensorflow - python

I am faced with the task of classifying sound by spectrograms. I have a solution to this problem in one way (I will convert all audio recordings into spectrograms -> save them as pictures and train a neural network for this), but I want to go the simpler way, that is, not save pictures, but immediately convert audio files into tensors, but there is a problem, I can't find any useful information on how to create my data set from tensors in TensorFlow. I will give an example of such code on Pytorch.
class SoundDataset(Dataset):
def __init__(self, file_names, labels):
self.file_names = file_names
self.labels = labels
def __getitem__(self,index):
#format the file path and load the file
path = self.file_names[index]
scale, sr = librosa.load(path)
filter_banks = librosa.filters.mel(n_fft=2048, sr=22050, n_mels=10)
mel_spectrogram = librosa.feature.melspectrogram(scale, sr=sr, n_fft=2048, hop_length=512, n_mels=32)
log_mel_spectrogram = librosa.power_to_db(mel_spectrogram)
trch = torch.from_numpy(log_mel_spectrogram)
if log_mel_spectrogram.shape !=(10,87):
delta = 87 - log_mel_spectrogram.shape[1]
trch = torch.nn.functional.pad(trch, (0,delta))
return trch,self.labels[index]
def __len__(self):
return len(self.file_names)
Here a class is being created that takes paths to audio recordings and converts them into tensors, and will pad zeros if the tensors do not fit the shapes. How can I create the same class for TensorFlow. Next is an example of code that creates tuples with file paths and their class and creates an object of the Sound Data set class and generates a dataset from these files accordingly. All this is written for Pytorch. Tell me how it can be implemented for TensorFlow.
path = '/content/drive/MyDrive/МДМА/audiodata/for-rerecorded/training/'
files = []
labels = []
lbl = '1 0'.split()
for lab in lbl:
if lab == '0':
c = 'fake'
else:
c ='real'
names = os.listdir(path+c)
for n in names:
pth = path+c+'/'+n
files.append(pth)
labels.append(int(lab))
train_dataset = SoundDataset(files, labels)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = 20)

If you read the documentation there are code patterns.
This is not tested but if you load the index from another data structure which has mapped the files to the indexes then this code can help.
import librosa
import pathlib
import tensorflow as tf
DATASET_PATH = 'data/mini_speech_commands'
data_dir = pathlib.Path(DATASET_PATH)
if not data_dir.exists():
tf.keras.utils.get_file(
'mini_speech_commands.zip',
origin="http://storage.googleapis.com/download.tensorflow.org/data/mini_speech_commands.zip",
extract=True,
cache_dir='.', cache_subdir='data')
def load_audio(filename):
scale, sr = librosa.load(filename)
mel_spectrogram = librosa.feature.melspectrogram(scale, sr=sr, n_fft=2048, hop_length=512, n_mels=32)
log_mel_spectrogram = librosa.power_to_db(mel_spectrogram)
spectrogram_numpy = log_mel_spectrogram.numpy()
if log_mel_spectrogram.shape !=(10,87):
delta = 87 - log_mel_spectrogram.shape[1]
spectrogram_numpy = tf.pad(spectrogram_numpy, (0,delta))
return spectrogram_numpy #return index
read_audio = lambda x: tf.py_function(load_audio,
[x],
tf.float64)
filenames = tf.io.gfile.glob(str(data_dir) + '/*/*')
files_ds = tf.data.Dataset.from_tensor_slices(filenames)
waveform_ds = files_ds.map(
map_func=read_audio)
The code to pad is converted using TensorFlow directly and you have to test it.
Update : Another way using keras.utils.Sequence is shown in this thread

Related

Pytorch model weights change when put on GPU

I noticed a very strange behaviour regarding the 3D Resnet by Facebookresearch. Using their sample code from the website, I receive different results, when putting the model on GPU. While on cpu the correct class (archery) is predicted, the model fails to predict it on GPU. Can anyone replicate this and confirm that this is indeed the case? Does anyone know, why this is happening and how to prevent it? Following, you will find some code to quickly test it out:
import torch
import json
import urllib
from pytorchvideo.data.encoded_video import EncodedVideo
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
CenterCropVideo,
NormalizeVideo,
)
from pytorchvideo.transforms import (
ApplyTransformToKey,
ShortSideScale,
UniformTemporalSubsample
)
def predict_archery(model, device):
json_url = "https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json"
json_filename = "kinetics_classnames.json"
try:
urllib.URLopener().retrieve(json_url, json_filename)
except:
urllib.request.urlretrieve(json_url, json_filename)
with open(json_filename, "r") as f:
kinetics_classnames = json.load(f)
# Create an id to label name mapping
kinetics_id_to_classname = {}
for k, v in kinetics_classnames.items():
kinetics_id_to_classname[v] = str(k).replace('"', "")
side_size = 256
mean = [0.45, 0.45, 0.45]
std = [0.225, 0.225, 0.225]
crop_size = 256
num_frames = 8
sampling_rate = 8
frames_per_second = 30
# Note that this transform is specific to the slow_R50 model.
transform = ApplyTransformToKey(
key="video",
transform=Compose(
[
UniformTemporalSubsample(num_frames),
Lambda(lambda x: x / 255.0),
NormalizeVideo(mean, std),
ShortSideScale(
size=side_size
),
CenterCropVideo(crop_size=(crop_size, crop_size))
]
),
)
# The duration of the input clip is also specific to the model.
clip_duration = (num_frames * sampling_rate) / frames_per_second
url_link = "https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4"
video_path = 'archery.mp4'
try:
urllib.URLopener().retrieve(url_link, video_path)
except:
urllib.request.urlretrieve(url_link, video_path)
# Select the duration of the clip to load by specifying the start and end duration
# The start_sec should correspond to where the action occurs in the video
start_sec = 0
end_sec = start_sec + clip_duration
# Initialize an EncodedVideo helper class and load the video
video = EncodedVideo.from_path(video_path)
# Load the desired clip
video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
# Apply a transform to normalize the video input
video_data = transform(video_data)
# Move the inputs to the desired device
inputs = video_data["video"]
inputs = inputs.to(device)
# Pass the input clip through the model
preds = model(inputs[None, ...])
# Get the predicted classes
post_act = torch.nn.Softmax(dim=1)
preds = post_act(preds)
pred_classes = preds.topk(k=5).indices[0]
# Map the predicted classes to the label names
pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes]
print("Top 5 predicted labels: %s" % ", ".join(pred_class_names))
if __name__ == '__main__':
# Choose device
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device("cpu")
# Choose the `slow_r50` model
model = torch.hub.load('facebookresearch/pytorchvideo', 'slow_r50', pretrained=True).to(device)
model = model.eval()
predict_archery(model, device)
Results on cpu:
Top 5 predicted labels: archery, throwing axe, playing paintball,
stretching arm, riding or walking with horse
Results on GPU:
Top 5 predicted labels: flying kite, air drumming, beatboxing,
smoking, reading book
Edit:
Apparently, this issue cannot be reproduced on google colab. I therefore assume that the issue is related to the specific hardware / cuda version. I am using a NVIDIA TITAN Xp and cuda version 11.4.

Training stuck at Epoch 3 PyTorch

I am training a custom Encoder-Decoder network but the training gets stuck at Epoch 3. Nothing happens for about 2 hours. I will share the Dataset class and the DataLoader object. The version if CUDA and GPU can be seen in the pic below.
Training stuck here:
nvidia-smi output looks like this:
The __getitem__ method of the dataset class looks like this:
def __init__(self,
images_dir,
annots_dir,
train=True,
img_size=(512, 1536),
stride=4,
model='custom',
transforms=None):
"""
:param root: dataset directory
:param filenames: filenames inside the root directory
:param labels: Object Detection Labels
super(CustomDataset).__init__()
self.images_dir = images_dir
self.annots_dir = annots_dir
self.train = train
self.image_size = img_size
self.stride = stride
self.transforms = transforms
self.model = model
# Load the image and annotation files from the dataset
# self.image_files, self.annot_files = self._load_image_and_annot_files()
self.image_files = [os.path.join(self.images_dir, idx) for idx in os.listdir(self.images_dir)]
self.annot_files = [os.path.join(self.annots_dir, idx) for idx in os.listdir(self.annots_dir)]
def __getitem__(self, index):
"""
:param index: index...0 to N
:return: tensor_image and tensor_label
"""
# Image filename from _load_image_files()
# Load Image with _read_matrix() and label
curr_image_filename = self.image_files[index]
curr_annot_filename = self.annot_files[index]
# curr_image_filename = self.image_files[index]
# curr_annot_filename = self.annot_files[index]
np_image = self._read_matrix(raw_img=curr_image_filename)
np_image_normalized = np.squeeze(self._normalize_raw_img(np_image))
# label = self.labels[index]
boxes, classes, depths, tgts = self._load_annotations(curr_annot_filename)
# Normalize bounding boxes: range [0, 1]
targets_normalized = self._normalize_bbox(np_image_normalized, tgts)
# image and the corresponding label should be a tensor
torch_image = torch.from_numpy(np_image).reshape(1, 512, 1536).float() # dtype: torch.float64
torch_boxes = torch.from_numpy(boxes).type(torch.FloatTensor)
torch_depths = torch.from_numpy(depths)
if self.model == 'fasterrcnn':
# For FasterRCNN: As COCO format
area = (torch_boxes[:, 3] - torch_boxes[:, 1]) * (torch_boxes[:, 2] - torch_boxes[:, 0])
iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
image_id = torch.Tensor([index])
torch_classes = torch.from_numpy(classes)
target = {'boxes': torch_boxes, 'labels': torch_classes.long(),
'area': area, 'iscrowd': iscrowd, 'image_id': image_id}
return torch_image, target
elif self.model == 'custom':
if self.train:
if self.transforms:
try:
tr = self.transforms()
transform_image, transform_boxes, labels = tr.__call__(np_image, tgts, tgts[:, :4], tgts[:, 4:])
transform_targets = np.hstack((np.array(transform_boxes), labels))
gt_tensor = gt_creator(img_size=self.image_size,
stride=self.stride,
num_classes=8,
label_lists=transform_targets)
return torch.from_numpy(transform_image).float(), gt_tensor
except IndexError:
pass
else:
gt_tensor = gt_creator(img_size=self.image_size,
stride=self.stride,
num_classes=8,
label_lists=targets_normalized)
return torch_image, gt_tensor
else:
return torch_image, targets_normalized
And in the train.py script the DataLoader object is:
train_loader = torch.utils.data.DataLoader(dataset=dataset,
shuffle=True,
batch_size=1,
num_workers=0,
collate_fn=detection_collate,
pin_memory=True)
Why does the training get stuck? Is there an issue with the __getitem__ method? Or the DataLoader?
Thank You.
This happens because torch doesnt restart your dataset, if your data runs out it stops and waits for more input so cycling has to be done manually.
I used something along the lines of
from itertools import cycle
class Dataloader():
#init and whatever
self.__iter__():
return cycle(get_sample()) # get_sample is your current getitem

Skip image during tensorflow input pipeline

I have a Tensorflow input pipeline that reads in two png files (example, label) from disk. I want to tell tensorflow to skip an example/label pair based on a value in the label. Anyone know how to do this?
Here is a simplified example of the input pipeline and with a comment where I want to do the filtering:
import tensorflow as tf
import glob2 as glob
def preprocess_images(impath, labpath):
image = tf.io.read_file(impath)
label = tf.io.read_file(labpath)
image = tf.image.decode_png(image, channels=3)
label = tf.image.decode_png(label, channels=1)
if tf.reduce_sum(label) == 0:
#skip the image and move on to the next, don't include this in the batch
else:
return (image, label)
im_files = glob.glob(impath + '*.png')
lab_files = glob.glob(labpath + '*.png')
files = (im_files, lab_files)
path = tf.data.Dataset.from_tensor_slices(files)
pair = path.map(preprocess_images)
ds = tf.data.Dataset.zip(pair)
ds = ds.batch(64)
The easiest way seems to be to use filter method on your tf.data.Dataset object.
Here I am going to load the label only and filter out the entries with a sum of 0:
def load_label_only(impath, labpath):
label = tf.io.read_file(labpath)
label = tf.image.decode_png(label, channels=1)
return impath, label
# Create the dataset as in your example:
im_files = glob.glob(impath + '*.png')
lab_files = glob.glob(labpath + '*.png')
files = (im_files, lab_files)
ds = tf.data.Dataset.from_tensor_slices(files)
ds = ds.map(load_label_only)
# Here, I am going to keep only non-zero labels:
filtered_ds = ds.filter(lambda image_path, label_map: tf.reduce_sum(label_map) != 0)
# Load the rest of the images...

Error preprocessing the input data when using Tensorflow Dataset API

I have images of [64,512,5] stored in *.npy files which I convert into *.tfrecords files.
I have verified that the reading of said records corresponds correctly with what is present in the *.npy files. However, when I perform some operation on the parser, like adding 1 to each pixel of the image, the result is not the expected one. The result should be 65*512*5 = 163840 but it is 163839.99980013957 (not always the same)
I have tried to perform different operations like tf.subtract, but the results are the same.
Could someone tell me what is wrong?
import re
import ast
import sys, select
import random as rn
from glob import glob
from tqdm import tqdm
from datetime import datetime
from configparser import SafeConfigParser
import numpy as np
import numpy.ma as ma
import scipy.misc
import os.path
from os import mkdir, stat
from os.path import exists, dirname, abspath
from os.path import join as dir_join
import tensorflow as tf
''' File hierarchy
'''
_code_dir = dirname(abspath(__file__))
_python_dir = dirname(_code_dir)
_model_dir = dirname(_python_dir)
_project_dir = dirname(_model_dir)
_ml_dir = dirname(_project_dir)
_srv_dir = dirname(_ml_dir)
_root_datasets_dir = dir_join(_srv_dir,'machine_learning','data_sets/ssd_prepared')
_config_dir = dir_join(_python_dir, 'config')
'''Data sets directories
'''
THIS_DATA_SET_DIR = 'Sph_50m' #WARNING: Global variable also used in helper.py
_data_dir = dir_join(_root_datasets_dir, THIS_DATA_SET_DIR)
_data_set_dir = dir_join(_data_dir,'ImageSet')
_data_npy_dir = dir_join(_data_dir,'data')
_data_tfRecord_dir = dir_join(_data_dir,'tfRecord')
''' Configuration parser
'''
cfg_parser = SafeConfigParser()
cfg_parser.read(dir_join(_config_dir,'cfg_model.ini'))
''' Private variables
'''
_batch_size = cfg_parser.getint(section='train', option='batch_size')
_max_epoch = cfg_parser.getint(section='train', option='max_epoch')
_standarize = cfg_parser.getboolean(section='train', option='standarize_input')
_input_shape = ast.literal_eval(cfg_parser.get(section='data_shape', option='input_shape'))
_label_channel = cfg_parser.getint(section='data_shape', option='label_channel')
_track_channel = cfg_parser.getint(section='data_shape', option='track_channel')
_mask_channel = cfg_parser.getint(section='data_shape', option='mask_channel')
_data_train = cfg_parser.get(section='data_set', option='data_train')
_data_val = cfg_parser.get(section='data_set', option='data_val')
_data_test = cfg_parser.get(section='data_set', option='data_test')
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value.reshape(-1)))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _floats_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value.reshape(-1)))
def numpy_to_TFRecord():
if not exists(_data_tfRecord_dir): mkdir(_data_tfRecord_dir)
for dataset in [_data_train, _data_val, _data_test]:
tfRecord_folder = dir_join(_data_tfRecord_dir, dataset)
if not exists(tfRecord_folder): mkdir(tfRecord_folder)
#Retrieve list of files
projections_dir=[]
file_ = open(dir_join(_data_set_dir, dataset+'.txt'), 'r')
for x in file_.readlines():
file_nat = x.strip()+'.npy'
filename = dir_join(_data_npy_dir, file_nat)
assert exists(filename), "{} doesn't exist".format(filename)
projections_dir.append(filename)
file_.close()
totaltfRecordSize = 0
numFile = 0
for projection_dir in tqdm(projections_dir, ncols= 100, desc = 'TFRecord {}'.format(dataset)):
scanName = projection_dir.split('/')[-1].split('.')[0]
if totaltfRecordSize > 100*(10**6) or totaltfRecordSize == 0:
# address to save the TFRecords file
train_filename = dir_join(tfRecord_folder, \
str(numFile) + '_' + dataset +'.tfrecords')
# open the TFRecords file
writer = tf.python_io.TFRecordWriter(train_filename)
numFile += 1
totaltfRecordSize = 0
# Load the image
projection = np.load(projection_dir)
image = projection[:,:,:_label_channel]
label = projection[:,:,_label_channel].astype(int)
mask = projection[:,:,_mask_channel].astype(int)
track = projection[:,:,_track_channel].astype(int)
# Create a feature
feature = {'image': _floats_feature(image),
'label': _int64_feature(label),
'mask' : _int64_feature(mask),
'track': _int64_feature(track),
'scanName': _bytes_feature(tf.compat.as_bytes(scanName))}
# Create an example protocol buffer
example = tf.train.Example(features=tf.train.Features(feature=feature))
# Serialize to string and write on the file
writer.write(example.SerializeToString())
fileSize = stat(train_filename).st_size
totaltfRecordSize += fileSize
writer.close()
sys.stdout.flush()
def readTFRecord():
# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
image_dim = _input_shape[0] * _input_shape[1] * _label_channel
label_dim = _input_shape[0] * _input_shape[1]
mean = np.load(dir_join(_data_dir,'mean.npy'))
std = np.load(dir_join(_data_dir,'std.npy'))
mean_tf = tf.convert_to_tensor(mean, dtype=tf.float32, name='mean')
std_tf = tf.convert_to_tensor(std, dtype=tf.float32, name='std')
with tf.variable_scope('TFRecord'):
def _parse_function(example_proto):
with tf.variable_scope('parser'):
features = {'image': tf.FixedLenFeature([image_dim], tf.float32),
'label': tf.FixedLenFeature([label_dim], tf.int64),
'mask' : tf.FixedLenFeature([label_dim], tf.int64),
'track': tf.FixedLenFeature([label_dim], tf.int64),
'scanName': tf.FixedLenFeature([], tf.string)}
parsed_features = tf.parse_single_example(example_proto, features)
# Reshape image data into the original shape
image = tf.reshape(parsed_features['image'], [_input_shape[0], _input_shape[1], _label_channel], name='image')
label = tf.reshape(parsed_features['label'], _input_shape, name='lable_reshape')
mask = tf.reshape(parsed_features['mask'], _input_shape, name='mask_reshape')
track = tf.reshape(parsed_features['track'], _input_shape, name='track_reshape')
scanName = parsed_features['scanName']
image = image + tf.constant(1., dtype=tf.float32)
return image, label, mask, track, scanName
training_filenames = glob(dir_join(_data_tfRecord_dir, _data_train, '*.tfrecords'))
validation_filenames = glob(dir_join(_data_tfRecord_dir, _data_val, '*.tfrecords'))
filenames = tf.placeholder(tf.string, shape=[None], name='filenames')
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function, num_parallel_calls=20) # Parse the record into tensors.
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(_batch_size, drop_remainder=True)
dataset = dataset.prefetch(buffer_size=10)
iterator = dataset.make_initializable_iterator()
next = iterator.get_next()
sess = tf.Session()
while True:
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
try:
img, _, _, _, scanX = sess.run(next)
for i, scan in enumerate(scanX):
print(scan.decode("utf-8"))
projection = np.load(dir_join(_data_npy_dir, scan.decode("utf-8") + '.npy'))
imagenp = projection[:,:,:_label_channel]
if np.abs(np.sum(img[i,...] - imagenp)) > 0.:
print(np.sum(img[i,...] - imagenp))
except tf.errors.OutOfRangeError:
break
return training_filenames, validation_filenames, filenames, iterator
if __name__ == '__main__':
numpy_to_TFRecord()
readTFRecord()
The test I'm doing in the previous code is to convert the *.npy files to *.tfrecords. Then, I compare the *.trecords with the *.npy. The value should be 0 if both images were identical.
img, _, _, _, scanX = sess.run(next)
for i, scan in enumerate(scanX):
print(scan.decode("utf-8"))
projection = np.load(dir_join(_data_npy_dir, scan.decode("utf-8") + '.npy'))
imagenp = projection[:,:,:_label_channel]
print(np.sum(img[i,...] - imagenp))
If the data is not preprocessed, these images are the same, however, if we perform some kind of transformation, the results do not match what was expected. In this case we are adding 1 to each pixel of the image, so the total difference should be 64 * 512 * 5.
image = image + tf.constant(1., dtype=tf.float32)
I would like to solve this error, since so far I have not been able to obtain the results obtained by my neural network using feed_dict instead of Tensorflow Dataset API, and this is the only point where I can observe a difference in the input data.

Broadcasting a keras model with pyspark [duplicate]

I am using Caffe to do image classification, can I am using MAC OS X, Pyhton.
Right now I know how to classify a list of images using Caffe with Spark python, but if I want to make it faster, I want to use Spark.
Therefore, I tried to apply the image classification on each element of an RDD, the RDD created from a list of image_path. However, Spark does not allow me to do so.
Here is my code:
This is the code for image classification:
# display image name, class number, predicted label
def classify_image(image_path, transformer, net):
image = caffe.io.load_image(image_path)
transformed_image = transformer.preprocess('data', image)
net.blobs['data'].data[...] = transformed_image
output = net.forward()
output_prob = output['prob'][0]
pred = output_prob.argmax()
labels_file = caffe_root + 'data/ilsvrc12/synset_words.txt'
labels = np.loadtxt(labels_file, str, delimiter='\t')
lb = labels[pred]
image_name = image_path.split(images_folder_path)[1]
result_str = 'image: '+image_name+' prediction: '+str(pred)+' label: '+lb
return result_str
This this the code generates Caffe parameters and apply the classify_image method on each element of the RDD:
def main():
sys.path.insert(0, caffe_root + 'python')
caffe.set_mode_cpu()
model_def = caffe_root + 'models/bvlc_reference_caffenet/deploy.prototxt'
model_weights = caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'
net = caffe.Net(model_def,
model_weights,
caffe.TEST)
mu = np.load(caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy')
mu = mu.mean(1).mean(1)
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2,0,1))
transformer.set_mean('data', mu)
transformer.set_raw_scale('data', 255)
transformer.set_channel_swap('data', (2,1,0))
net.blobs['data'].reshape(50,
3,
227, 227)
image_list= []
for image_path in glob.glob(images_folder_path+'*.jpg'):
image_list.append(image_path)
images_rdd = sc.parallelize(image_list)
transformer_bc = sc.broadcast(transformer)
net_bc = sc.broadcast(net)
image_predictions = images_rdd.map(lambda image_path: classify_image(image_path, transformer_bc, net_bc))
print image_predictions
if __name__ == '__main__':
main()
As you can see, here I tried to broadcast the caffe parameters, transformer_bc = sc.broadcast(transformer), net_bc = sc.broadcast(net)
The error is:
RuntimeError: Pickling of "caffe._caffe.Net" instances is not enabled
Before I am doing the broadcast, the error was :
Driver stacktrace.... Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):....
So, do you know, is there any way I can classify images using Caffe and Spark but also take advantage of Spark?
When you work with complex, non-native objects initialization has to moved directly to the workers for example with singleton module:
net_builder.py:
import cafe
net = None
def build_net(*args, **kwargs):
... # Initialize net here
return net
def get_net(*args, **kwargs):
global net
if net is None:
net = build_net(*args, **kwargs)
return net
main.py:
import net_builder
sc.addPyFile("net_builder.py")
def classify_image(image_path, transformer, *args, **kwargs):
net = net_builder.get_net(*args, **kwargs)
It means you'll have to distribute all required files as well. It can be done either manually or using SparkFiles mechanism.
On a side note you should take a look at the SparkNet package.

Categories

Resources