import { extend, ReactThreeFiber } from "@react-three/fiber";
import glsl from "babel-plugin-glsl/macro";
import { shaderMaterial } from "@react-three/drei";
import { Color, DoubleSide, ShaderMaterial, CustomBlending, AddEquation, SrcAlphaFactor, OneMinusSrcAlphaFactor, OneFactor } from "three";

/**
 * Shader material for a swaying plant texture that cross-fades between three
 * seasonal textures (foliage -> bloom -> winter -> foliage) driven by the
 * `stage` uniform, optionally remapping colours through a LUT texture when
 * `pollinatorView` is enabled.
 *
 * drei's `shaderMaterial` turns every key of the first argument into a uniform
 * and exposes it as a property on material instances, so each can be set as a
 * JSX prop on `<plantMaterial />`.
 */
const PlantMaterial = shaderMaterial(
  {
    // Elapsed time; the sway oscillates as sin(time * speed).
    time: 0,
    // NOTE(review): declared in the fragment shader but never read there.
    saturation: 0,
    // Maximum horizontal sway offset, in UV units.
    amplitude: 0.2,
    // Multiplier applied to `time` for the sway oscillation.
    speed: 0.1,
    // Season index, wrapped modulo 3 in the shader (0 foliage, 1 bloom,
    // 2 winter); fractional values cross-fade between neighbouring textures.
    stage: 0,
    mapFolliage: null,
    mapBloom: null,
    mapWinter: null,
    // Colour lookup table applied via glsl-lut when `pollinatorView` is true.
    mapLUT: null,
    // NOTE(review): `side`, `transparent` and `depthWrite` are render-state
    // flags, not shader uniforms. They only take effect because drei mirrors
    // every key onto the material as a property accessor — confirm against the
    // installed drei version before moving or renaming them.
    side: DoubleSide,
    pollinatorView: false,
    transparent: true,
    depthWrite: false,
    // true: use the textures' own alpha channel.
    // false: key out near-white texels and discard fragments whose composited
    // alpha is below 0.5.
    usePng: false,
  },
  // vertex shader: pass UVs through with the standard model-view-projection transform.
  glsl`
    varying vec2 vUv;
    void main() {
      vUv = uv;
      gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
    }
  `,
  // fragment shader
  glsl`
    #define LUT_FLIP_Y 1
    #pragma glslify: transform = require(glsl-lut)

    #ifdef GL_ES
    precision highp float;
    #endif

    uniform float time;
    uniform float saturation;
    uniform float amplitude;
    uniform float speed;

    uniform bool usePng;

    uniform bool pollinatorView;

    uniform float stage;

    uniform sampler2D mapFolliage;
    uniform sampler2D mapBloom;
    uniform sampler2D mapWinter;

    uniform sampler2D mapLUT;

    varying vec2 vUv;

    // Gradient magnitude of the alpha channel from a 3x3 Sobel kernel.
    // Currently unused. textureSize requires GLSL ES 3.00 (WebGL2).
    float sobel(sampler2D tex, vec2 uv) {
      vec2 size = vec2(1.0) / vec2(textureSize(tex, 0));  // Size of one pixel
      float Gx =
          texture2D(tex, uv + vec2(-1.0, -1.0) * size).a * -1.0 +
          texture2D(tex, uv + vec2(0.0, -1.0) * size).a * -2.0 +
          texture2D(tex, uv + vec2(1.0, -1.0) * size).a * -1.0 +
          texture2D(tex, uv + vec2(-1.0, 1.0) * size).a * 1.0 +
          texture2D(tex, uv + vec2(0.0, 1.0) * size).a * 2.0 +
          texture2D(tex, uv + vec2(1.0, 1.0) * size).a * 1.0;
      float Gy =
          texture2D(tex, uv + vec2(-1.0, -1.0) * size).a * -1.0 +
          texture2D(tex, uv + vec2(-1.0, 0.0) * size).a * -2.0 +
          texture2D(tex, uv + vec2(-1.0, 1.0) * size).a * -1.0 +
          texture2D(tex, uv + vec2(1.0, -1.0) * size).a * 1.0 +
          texture2D(tex, uv + vec2(1.0, 0.0) * size).a * 2.0 +
          texture2D(tex, uv + vec2(1.0, 1.0) * size).a * 1.0;
      return sqrt(Gx*Gx + Gy*Gy);
    }

    // Blends color towards its weighted grey: 0.0 keeps color, 1.0 is fully grey.
    // Currently unused.
    vec3 Desaturate(vec3 color,float Desaturation)
    {
      vec3 grayXfer=vec3(.3,.59,.11);
      vec3 gray=vec3(dot(grayXfer,color));
      return mix(color,gray,Desaturation);
    }

    // Blends foreground over background, weighted by foreground.a.
    // NOTE(review): all four components are weighted, so the result alpha is
    // fg.a * fg.a + bg.a * (1 - fg.a) rather than the usual over-operator alpha
    // fg.a + bg.a * (1 - fg.a). Kept as-is because the 0.5 discard in main
    // depends on it; revisit if cross-fades look too transparent.
    vec4 layer(vec4 foreground, vec4 background) {
      return foreground * foreground.a + background * (1.0 - foreground.a);
    }

    // Converts RGB to HSB, packed as (hue, saturation, brightness) in x, y, z.
    vec3 rgb2hsb( in vec3 c ) {
      vec4 K = vec4(0.0, -1.0 / 3.0, 2.0 / 3.0, -1.0);
      vec4 p = mix(vec4(c.bg, K.wz),
                  vec4(c.gb, K.xy),
                  step(c.b, c.g));
      vec4 q = mix(vec4(p.xyw, c.r),
                  vec4(c.r, p.yzx),
                  step(p.x, c.r));
      float d = q.x - min(q.w, q.y);
      float e = 1.0e-10;
      return vec3(abs(q.z + (q.w - q.y) / (6.0 * d + e)),
                  d / (q.x + e),
                  q.x);
    }

    // Sets c.a to a, except near-white texels (saturation below threshold and
    // brightness above 1 - threshold), which are made transparent.
    // NOTE(review): inside the branch luminance is always below threshold, so
    // the smoothstep always yields 0.0, which is a hard cut-out. A soft edge
    // may have been intended.
    void useWhiteAsTransparent(inout vec4 c, float a) {
      vec3 hsb = rgb2hsb(c.rgb);
      c.a = a;

      float threshold = 0.1;

      if (hsb.y < threshold && hsb.z > (1.0 - threshold)) {
        float sum = hsb.y + 1.0 - hsb.z;
        float luminance = sum / 2.0;
        c.a = smoothstep(1.0 - threshold, 1.0, luminance);
      }
    }

    // Returns the texel alpha, pushed towards 1.0 where the alpha channel has a
    // strong edge (Sobel gradient above edgeThreshold). Currently unused.
    // textureSize requires GLSL ES 3.00 (WebGL2).
    float sobelOutlineAlpha(sampler2D tex, vec2 uv, float edgeThreshold) {
      float alpha = texture2D(tex, uv).a;

      // One texel in UV space; neighbours must be one pixel away, not one
      // whole texture away.
      vec2 texel = 1.0 / vec2(textureSize(tex, 0));

      // Sobel operator kernels
      mat3 sobelX = mat3(-1, 0, 1, -2, 0, 2, -1, 0, 1);
      mat3 sobelY = mat3(-1, -2, -1, 0, 0, 0, 1, 2, 1);

      float gradientX = 0.0;
      float gradientY = 0.0;

      for (int i = -1; i <= 1; i++) {
        for (int j = -1; j <= 1; j++) {
          vec2 offset = vec2(float(i), float(j)) * texel;
          float neighborAlpha = texture2D(tex, uv + offset).a;
          gradientX += neighborAlpha * sobelX[i + 1][j + 1];
          gradientY += neighborAlpha * sobelY[i + 1][j + 1];
        }
      }

      float gradient = length(vec2(gradientX, gradientY));
      float edgeAlpha = smoothstep(edgeThreshold, edgeThreshold * 1.2, gradient);
      return mix(alpha, edgeAlpha, edgeAlpha);
    }

    // Cross-fade weight of the stage centred at center: 1.0 when stageMod equals
    // center, easing to 0.0 once abs(stageMod - center) * spread reaches 1.
    // abs() is required because pow() is undefined for negative bases in GLSL;
    // for the even exponent used in main it does not change the result.
    float stageAlpha(float stageMod, float center, float spread, float exponent) {
      return 1.0 - smoothstep(0.0, 1.0, pow(abs((stageMod - center) * spread), exponent));
    }

    // Samples one seasonal texture and applies its cross-fade weight exactly once.
    // PNG mode scales the texture alpha; otherwise near-white texels are keyed out.
    vec4 stageColor(sampler2D tex, vec2 uv, float weight) {
      vec4 c = texture2D(tex, uv);
      if (usePng) {
        c.a *= weight;
      } else {
        useWhiteAsTransparent(c, weight);
      }
      return c;
    }

    void main(void)
    {
      // Sway: widen the horizontal UV range by amplitude on each side and shift
      // it by an offset oscillating in [-amplitude, amplitude] over time, scaled
      // by vUv.y squared (zero at vUv.y = 0, largest at vUv.y = 1).
      float t1 = sin(time * speed) * 0.5 + 0.5;
      float a = mix(-amplitude, amplitude, t1);
      vec2 vUVr = vec2(mix(-amplitude, 1.0 + amplitude, vUv.x) + vUv.y * vUv.y * a, vUv.y);

      // Stages repeat every 3 units: 0 foliage, 1 bloom, 2 winter.
      float stage_mod = mod(stage, 3.0);
      float cross_fade_spread = 1.1;
      float power = 2.0;

      vec4 foliage = stageColor(mapFolliage, vUVr, stageAlpha(stage_mod, 0.0, cross_fade_spread, power));
      vec4 bloom = stageColor(mapBloom, vUVr, stageAlpha(stage_mod, 1.0, cross_fade_spread, power));
      vec4 winter = stageColor(mapWinter, vUVr, stageAlpha(stage_mod, 2.0, cross_fade_spread, power));
      // Foliage again, centred at 3.0, so winter fades back into foliage as
      // stage_mod approaches 3.0 and wraps to 0.0.
      vec4 foliageWrap = stageColor(mapFolliage, vUVr, stageAlpha(stage_mod, 3.0, cross_fade_spread, power));

      // Later stages are layered on top of earlier ones.
      vec4 Cf = layer(foliageWrap, layer(winter, layer(bloom, foliage)));

      // Without a real alpha channel the keyed result is hard-cut at 0.5.
      if (!usePng && Cf.a < 0.5) discard;

      if (pollinatorView) Cf = transform(Cf, mapLUT);

      gl_FragColor = Cf;
    }
  `
);

// Register PlantMaterial with react-three-fiber so it can be rendered as
// <plantMaterial /> in JSX.
extend({ PlantMaterial });

// Declare the <plantMaterial /> intrinsic element for TypeScript's JSX checker.
declare global {
  namespace JSX {
    interface IntrinsicElements {
      // NOTE(review): `any` disables prop checking for this element. A typed
      // alternative (e.g. ReactThreeFiber.MaterialNode plus the uniform props)
      // would catch misspelled uniforms such as mapFolliage — but would also
      // reject any props callers currently pass that are not declared.
      plantMaterial: any;
    }
  }
}
