# Fit a (nearly unregularized) logistic regression, then evaluate its
# predicted positive-class probability on a 20x20 grid that spans the range
# of the first two feature columns.
# NOTE(review): assumes X has exactly two feature columns and y holds binary
# labels — TODO confirm; with more features the surface below ignores them.
logistic_model = LogisticRegression(C=1e5)  # large C => very weak L2 penalty
logistic_model.fit(X, y)
x_range = np.linspace(X[:, 0].min(), X[:, 0].max(), 20)
y_range = np.linspace(X[:, 1].min(), X[:, 1].max(), 20)
XX, YY = np.meshgrid(x_range, y_range)
# P = sigmoid(w0 * x0 + w1 * x1 + b). Each grid axis must be weighted by its
# own coefficient: coef_ has shape (1, n_features) for a binary model.
P = expit(
    XX * logistic_model.coef_[0, 0]
    + YY * logistic_model.coef_[0, 1]
    + logistic_model.intercept_[0]
)
# Plot the training points (height = label y) together with the fitted
# probability surface P evaluated over the (XX, YY) grid.
with sns.plotting_context("talk"):
    fig = plt.figure(figsize=(10, 7))
    ax = plt.axes(projection="3d")
    # NOTE(review): mplot3d computes its own draw order by default
    # (computed_zorder=True), so zorder here may have no effect and points
    # can be hidden behind the opaque surface.
    ax.scatter(X[:, 0], X[:, 1], y, c=y, cmap="jet", zorder=10, edgecolor="k")
    ax.plot_surface(XX, YY, P, alpha=1.0, cmap="viridis", edgecolor="k")