From 3e0610f41304c33f12f9f09d1d1a9c371085d11d Mon Sep 17 00:00:00 2001 From: Ryan Wang Date: Tue, 8 Sep 2020 23:11:09 -0400 Subject: [PATCH] feature: add vis_args for passing visdom arguments --- project/datasets.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/project/datasets.py b/project/datasets.py index 49d0bf0e..3a2a5754 100644 --- a/project/datasets.py +++ b/project/datasets.py @@ -22,10 +22,10 @@ def make_pts(N): class Graph: - def __init__(self, vis=False): + def __init__(self, vis=False, vis_args={}): self.gifs = [] if vis: - self.vis = visdom.Visdom() + self.vis = visdom.Visdom(**vis_args) else: self.vis = None self.first = True @@ -74,8 +74,8 @@ def graph(self, outfile, model=None): class Simple(Graph): - def __init__(self, N, vis=False): - super().__init__(vis) + def __init__(self, N, vis=False, vis_args={}): + super().__init__(vis, vis_args) self.N = N self.X = make_pts(N) self.y = [] @@ -85,8 +85,8 @@ def __init__(self, N, vis=False): class Split(Graph): - def __init__(self, N, vis=False): - super().__init__(vis) + def __init__(self, N, vis=False, vis_args={}): + super().__init__(vis, vis_args) self.N = N self.X = make_pts(N) self.y = [] @@ -96,8 +96,8 @@ def __init__(self, N, vis=False): class Xor(Graph): - def __init__(self, N, vis=False): - super().__init__(vis) + def __init__(self, N, vis=False, vis_args={}): + super().__init__(vis, vis_args) self.N = N self.X = make_pts(N) self.y = []